#!/usr/bin/env python3 # coding:utf-8 # Copyright (c) 2022 SDCI Co. Ltd (author: veelion) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import json import time import asyncio import argparse import websockets import soundfile as sf import statistics WS_START = json.dumps({ 'signal': 'start', 'nbest': 1, 'continuous_decoding': False, }) WS_END = json.dumps({ 'signal': 'end' }) async def ws_rec(data, ws_uri): begin = time.time() conn = await websockets.connect(ws_uri, ping_timeout=200) # step 1: send start await conn.send(WS_START) ret = await conn.recv() # step 2: send audio data await conn.send(data) # step 3: send end await conn.send(WS_END) # step 4: receive result texts = [] while 1: ret = await conn.recv() ret = json.loads(ret) if ret['type'] == 'final_result': nbest = json.loads(ret['nbest']) text = nbest[0]['sentence'] texts.append(text) elif ret['type'] == 'speech_end': break # step 5: close try: await conn.close() except Exception as e: # this except has no effect, just log as debug # it seems the server does not send close info, maybe print(e) time_cost = time.time() - begin return { 'text': ''.join(texts), 'time': time_cost, } def get_args(): parser = argparse.ArgumentParser(description='') parser.add_argument( '-u', '--ws_uri', required=True, help="websocket_server_main's uri, e.g. ws://127.0.0.1:10086") parser.add_argument( '-w', '--wav_scp', required=True, help='path to wav_scp_file') parser.add_argument( '-t', '--trans', required=True, help='path to trans_text_file of wavs') parser.add_argument( '-s', '--save_to', required=True, help='path to save transcription') parser.add_argument( '-l', '--log', required=True, help='path to save throughput log file') parser.add_argument( '-n', '--num_concurrence', type=int, required=True, help='num of concurrence for query') args = parser.parse_args() return args def print_result(info, log_f): length = max([len(k) for k in info]) for k, v in info.items(): print(f'\t{k: >{length}} : {v}') log_f.write(f'\t{k: >{length}} : {v}\n') async def main(args): wav_scp = [] total_duration = 0 with open(args.wav_scp) as f: for line in f: zz = line.strip().split() assert len(zz) == 2 data, sr = sf.read(zz[1], dtype='int16') assert sr == 16000 duration = (len(data)) / 16000 total_duration += duration wav_scp.append((zz[0], data.tobytes())) print(f'{len(wav_scp) = }, {total_duration = }') tasks = [] failed = 0 texts = [] request_times = [] begin = time.time() for i, (_uttid, data) in enumerate(wav_scp): task = asyncio.create_task(ws_rec(data, args.ws_uri)) tasks.append((_uttid, task)) if len(tasks) < args.num_concurrence: continue print((f'{i=}, start {args.num_concurrence} ' f'queries @ {time.strftime("%m-%d %H:%M:%S")}')) for uttid, task in tasks: result = await task texts.append(f'{uttid}\t{result["text"]}\n') request_times.append(result['time']) tasks = [] print(f'\tdone @ {time.strftime("%m-%d %H:%M:%S")}') if tasks: for uttid, task in tasks: result = await task texts.append(f'{uttid}\t{result["text"]}\n') request_times.append(result['time']) request_time = time.time() - begin rtf = request_time / total_duration with open(args.log, 'w') as log_f: print('For all concurrence:') log_f.write('For all requests: \n') print_result({ 'failed': failed, 'total_duration': total_duration, 'request_time': request_time, 'RTF': rtf, }, log_f) print('For one request:') log_f.write('For one request: \n') print_result({ 'mean': statistics.mean(request_times), 'median': statistics.median(request_times), 'max_time': max(request_times), 'min_time': min(request_times), }, log_f) with open(args.save_to, 'w', encoding='utf8') as fsave: fsave.write(''.join(texts)) # caculate CER cmd = (f'python3 ./compute-wer.py --char=1 --v=1 ' f'{args.trans} {args.save_to} > ' f'{args.save_to}-test-{args.num_concurrence}.cer.txt') print(cmd) os.system(cmd) print('done') if __name__ == '__main__': args = get_args() asyncio.run(main(args))