Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
import gradio as gr
|
| 3 |
import librosa
|
|
|
|
| 4 |
|
| 5 |
from espnet2.bin.s2t_inference_language import Speech2Language
|
| 6 |
from espnet2.bin.s2t_inference import Speech2Text as ARSpeech2Text
|
|
@@ -75,49 +76,11 @@ Please consider citing the following papers if you find our work helpful.
|
|
| 75 |
# if not torch.cuda.is_available():
|
| 76 |
# raise RuntimeError("Please use GPU for better inference speed.")
|
| 77 |
|
| 78 |
-
|
| 79 |
|
| 80 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 81 |
|
| 82 |
|
| 83 |
-
try:
|
| 84 |
-
s2l = Speech2Language.from_pretrained(
|
| 85 |
-
model_tag=f"espnet/owsm_v4_medium_1B",
|
| 86 |
-
device=device,
|
| 87 |
-
nbest=1,
|
| 88 |
-
)
|
| 89 |
-
except Exception as e:
|
| 90 |
-
print("File downloaded")
|
| 91 |
-
|
| 92 |
-
# 2. Remove unrequired file
|
| 93 |
-
import yaml
|
| 94 |
-
from pathlib import Path
|
| 95 |
-
import espnet_model_zoo
|
| 96 |
-
|
| 97 |
-
d = "models--espnet--owsm_v4_medium_1B/snapshots/471418ddaf0b03c9ab1fd75f1f5d26fc3aea3aa9/exp/s2t_train_conv2d8_size1024_e18_d18_mel128_raw_bpe50000/config.yaml"
|
| 98 |
-
p = Path(espnet_model_zoo.__file__)
|
| 99 |
-
config_path = p.parent / d
|
| 100 |
-
|
| 101 |
-
def remove_key(obj, key="gradient_checkpoint_layers"):
|
| 102 |
-
if isinstance(obj, dict):
|
| 103 |
-
if key in obj:
|
| 104 |
-
del obj[key]
|
| 105 |
-
for k, v in list(obj.items()):
|
| 106 |
-
remove_key(v, key)
|
| 107 |
-
elif isinstance(obj, list):
|
| 108 |
-
for item in obj:
|
| 109 |
-
remove_key(item, key)
|
| 110 |
-
|
| 111 |
-
with open(config_path, "r") as f:
|
| 112 |
-
config = yaml.safe_load(f)
|
| 113 |
-
|
| 114 |
-
remove_key(config)
|
| 115 |
-
|
| 116 |
-
with open(config_path, "w") as f:
|
| 117 |
-
yaml.safe_dump(config, f, sort_keys=False, allow_unicode=True)
|
| 118 |
-
|
| 119 |
-
print("Done! All 'gradient_checkpoint_layers' keys removed.")
|
| 120 |
-
|
| 121 |
s2l = Speech2Language.from_pretrained(
|
| 122 |
model_tag=f"espnet/owsm_v4_medium_1B",
|
| 123 |
device=device,
|
|
@@ -138,12 +101,12 @@ s2t_ar = ARSpeech2Text.from_pretrained(
|
|
| 138 |
|
| 139 |
# CTC looks okay.
|
| 140 |
s2t_ctc = CTCSpeech2Text.from_pretrained(
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
|
| 148 |
|
| 149 |
iso_codes = ['abk', 'afr', 'amh', 'ara', 'asm', 'ast', 'aze', 'bak', 'bas', 'bel', 'ben', 'bos', 'bre', 'bul', 'cat', 'ceb', 'ces', 'chv', 'ckb', 'cmn', 'cnh', 'cym', 'dan', 'deu', 'dgd', 'div', 'ell', 'eng', 'epo', 'est', 'eus', 'fas', 'fil', 'fin', 'fra', 'frr', 'ful', 'gle', 'glg', 'grn', 'guj', 'hat', 'hau', 'heb', 'hin', 'hrv', 'hsb', 'hun', 'hye', 'ibo', 'ina', 'ind', 'isl', 'ita', 'jav', 'jpn', 'kab', 'kam', 'kan', 'kat', 'kaz', 'kea', 'khm', 'kin', 'kir', 'kmr', 'kor', 'lao', 'lav', 'lga', 'lin', 'lit', 'ltz', 'lug', 'luo', 'mal', 'mar', 'mas', 'mdf', 'mhr', 'mkd', 'mlt', 'mon', 'mri', 'mrj', 'mya', 'myv', 'nan', 'nep', 'nld', 'nno', 'nob', 'npi', 'nso', 'nya', 'oci', 'ori', 'orm', 'ory', 'pan', 'pol', 'por', 'pus', 'quy', 'roh', 'ron', 'rus', 'sah', 'sat', 'sin', 'skr', 'slk', 'slv', 'sna', 'snd', 'som', 'sot', 'spa', 'srd', 'srp', 'sun', 'swa', 'swe', 'swh', 'tam', 'tat', 'tel', 'tgk', 'tgl', 'tha', 'tig', 'tir', 'tok', 'tpi', 'tsn', 'tuk', 'tur', 'twi', 'uig', 'ukr', 'umb', 'urd', 'uzb', 'vie', 'vot', 'wol', 'xho', 'yor', 'yue', 'zho', 'zul']
|
|
@@ -187,6 +150,7 @@ def format_timestamp(
|
|
| 187 |
)
|
| 188 |
|
| 189 |
|
|
|
|
| 190 |
def predict(audio_path, src_lang: str, task: str, model_name: str, beam_size, long_form: bool, text_prev: str,):
|
| 191 |
task_sym = f'<{task2code[task]}>'
|
| 192 |
|
|
|
|
| 1 |
import torch
|
| 2 |
import gradio as gr
|
| 3 |
import librosa
|
| 4 |
+
import spaces
|
| 5 |
|
| 6 |
from espnet2.bin.s2t_inference_language import Speech2Language
|
| 7 |
from espnet2.bin.s2t_inference import Speech2Text as ARSpeech2Text
|
|
|
|
| 76 |
# if not torch.cuda.is_available():
|
| 77 |
# raise RuntimeError("Please use GPU for better inference speed.")
|
| 78 |
|
| 79 |
+
device = "cuda"
|
| 80 |
|
| 81 |
+
# device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 82 |
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
s2l = Speech2Language.from_pretrained(
|
| 85 |
model_tag=f"espnet/owsm_v4_medium_1B",
|
| 86 |
device=device,
|
|
|
|
| 101 |
|
| 102 |
# CTC looks okay.
|
| 103 |
s2t_ctc = CTCSpeech2Text.from_pretrained(
|
| 104 |
+
model_tag=f"espnet/owsm_ctc_v4_1B",
|
| 105 |
+
device=device,
|
| 106 |
+
lang_sym="<eng>",
|
| 107 |
+
task_sym="<asr>",
|
| 108 |
+
predict_time=False,
|
| 109 |
+
)
|
| 110 |
|
| 111 |
|
| 112 |
iso_codes = ['abk', 'afr', 'amh', 'ara', 'asm', 'ast', 'aze', 'bak', 'bas', 'bel', 'ben', 'bos', 'bre', 'bul', 'cat', 'ceb', 'ces', 'chv', 'ckb', 'cmn', 'cnh', 'cym', 'dan', 'deu', 'dgd', 'div', 'ell', 'eng', 'epo', 'est', 'eus', 'fas', 'fil', 'fin', 'fra', 'frr', 'ful', 'gle', 'glg', 'grn', 'guj', 'hat', 'hau', 'heb', 'hin', 'hrv', 'hsb', 'hun', 'hye', 'ibo', 'ina', 'ind', 'isl', 'ita', 'jav', 'jpn', 'kab', 'kam', 'kan', 'kat', 'kaz', 'kea', 'khm', 'kin', 'kir', 'kmr', 'kor', 'lao', 'lav', 'lga', 'lin', 'lit', 'ltz', 'lug', 'luo', 'mal', 'mar', 'mas', 'mdf', 'mhr', 'mkd', 'mlt', 'mon', 'mri', 'mrj', 'mya', 'myv', 'nan', 'nep', 'nld', 'nno', 'nob', 'npi', 'nso', 'nya', 'oci', 'ori', 'orm', 'ory', 'pan', 'pol', 'por', 'pus', 'quy', 'roh', 'ron', 'rus', 'sah', 'sat', 'sin', 'skr', 'slk', 'slv', 'sna', 'snd', 'som', 'sot', 'spa', 'srd', 'srp', 'sun', 'swa', 'swe', 'swh', 'tam', 'tat', 'tel', 'tgk', 'tgl', 'tha', 'tig', 'tir', 'tok', 'tpi', 'tsn', 'tuk', 'tur', 'twi', 'uig', 'ukr', 'umb', 'urd', 'uzb', 'vie', 'vot', 'wol', 'xho', 'yor', 'yue', 'zho', 'zul']
|
|
|
|
| 150 |
)
|
| 151 |
|
| 152 |
|
| 153 |
+
@spaces.GPU
|
| 154 |
def predict(audio_path, src_lang: str, task: str, model_name: str, beam_size, long_form: bool, text_prev: str,):
|
| 155 |
task_sym = f'<{task2code[task]}>'
|
| 156 |
|