import os
import subprocess as sp

import numpy as np
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
import gradio as gr
from inaSpeechSegmenter import Segmenter

# Let subprocess discard ffmpeg's stderr instead of keeping an open file handle around.
DEVNULL = sp.DEVNULL
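# Configuration comes from environment variables (e.g. Hugging Face Space secrets).
# Hypothetical example values, not the author's actual settings:
#   MODEL_REPO_ID='facebook/mms-1b-all', LANG='eng' (an ISO 639-3 code), HF_TOKEN=<token>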
model_id = os.getenv('MODEL_REPO_ID')
target_lang = os.getenv('LANG')
token = os.getenv('HF_TOKEN')
processor = AutoProcessor.from_pretrained(model_id, target_lang=target_lang, token=token)
model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang=target_lang, ignore_mismatched_sizes=True, token=token)
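# target_lang selects a per-language adapter on MMS-style checkpoints; the CTC head is
# resized to that language's vocabulary, which is why ignore_mismatched_sizes=True is set.
# A hedged sketch of switching languages at runtime (assuming such a checkpoint):
#   processor.tokenizer.set_target_lang('fra')
#   model.load_adapter('fra')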
seg = Segmenter()
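# inaSpeechSegmenter labels regions as 'male', 'female', 'music', 'noise', or 'noEnergy';
# only the speech labels below are transcribed.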
LABELS = {'male', 'female'}
def ffmpeg_load_audio(filename, sr=16_000, mono=True, normalize=False, in_type=np.int16, out_type=np.float32):
    '''Decode any ffmpeg-readable file to raw PCM and return (sr, samples).'''
    channels = 1 if mono else 2
    format_strings = {
        np.float64: 'f64le',
        np.float32: 'f32le',
        np.int16: 's16le',
        np.int32: 's32le',
        np.uint32: 'u32le'
    }
    format_string = format_strings[in_type]
    command = [
        'ffmpeg',
        '-i', filename,
        '-f', format_string,
        '-acodec', 'pcm_' + format_string,
        '-ar', str(sr),
        '-ac', str(channels),
        '-']
    p = sp.Popen(command, stdout=sp.PIPE, stderr=DEVNULL, bufsize=4096)
    bytes_per_sample = np.dtype(in_type).itemsize
    frame_size = bytes_per_sample * channels
    chunk_size = frame_size * sr  # read in 1-second chunks
    raw = b''
    with p.stdout as stdout:
        while True:
            data = stdout.read(chunk_size)
            if data:
                raw += data
            else:
                break
    p.wait()
    # np.fromstring was removed in NumPy 2.0; np.frombuffer is the supported equivalent.
    audio = np.frombuffer(raw, dtype=in_type).astype(out_type)
    if channels > 1:
        audio = audio.reshape((-1, channels)).transpose()
    if audio.size == 0:
        # keep the (sr, audio) return order consistent with the non-empty case below
        return sr, audio
    if issubclass(out_type, np.floating):
        if normalize:
            peak = np.abs(audio).max()
            if peak > 0:
                audio /= peak
        elif issubclass(in_type, np.integer):
            # scale integer PCM into [-1.0, 1.0]
            audio /= np.iinfo(in_type).max
    return sr, audio
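# A minimal usage sketch (hypothetical file name):
#   sr, audio = ffmpeg_load_audio('sample.mp3')  # float32 mono at 16 kHz
#   assert sr == 16_000 and audio.dtype == np.float32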
def preprocess(fin):
    '''Segment the uploaded file and yield one 16 kHz mono chunk per segment.'''
    # label speech/non-speech regions of the recording
    segmentation = seg(fin)
    # decode the whole file to mono 16 kHz float32
    sr, audio = ffmpeg_load_audio(fin)
    # yield one chunk per segment; non-speech segments carry no samples
    for label, start, end in segmentation:
        if label not in LABELS:
            data = None
        else:
            start_idx = int(sr * start)
            end_idx = int(sr * end)
            data = audio[start_idx:end_idx]
        yield (sr, data), (label, start, end)
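# A minimal consumption sketch (hypothetical file name; segments arrive in time order):
#   for (sr, data), (label, start, end) in preprocess('interview.wav'):
#       if data is not None:
#           print(f'{label} {start:.1f}-{end:.1f}s: {len(data) / sr:.1f}s of speech')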
def predict(fin):
    res = []
    for (sr, data), (label, start, end) in preprocess(fin):
        if data is None:
            # non-speech segment: insert a line break instead of a transcript
            res.append('\n')
            yield ' '.join(res)
            continue
        inputs = processor(data, sampling_rate=16_000, return_tensors="pt", padding=True)
        with torch.no_grad():
            outputs = model(**inputs).logits
        pred_ids = torch.argmax(outputs, dim=-1)[0]
        hyp = processor.decode(pred_ids)
        res.append(hyp)
        # yield the running transcript so the UI can stream partial results
        yield ' '.join(res)
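# predict() is a generator, so gr.Interface streams the growing transcript after each
# segment. A hedged local sketch (hypothetical path):
#   for partial in predict('interview.wav'):
#       print(partial)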
def main():
    demo = gr.Interface(fn=predict,
                        inputs=[gr.Audio(type='filepath')],
                        outputs="text")
    demo.launch()
if __name__ == '__main__':
    main()