Spaces:
Running
Running
| import argparse | |
| import asyncio | |
| import functools | |
| import json | |
| import os | |
| from io import BytesIO | |
| import uvicorn | |
| from fastapi import FastAPI, BackgroundTasks, File, Body, UploadFile, Request | |
| from fastapi.responses import StreamingResponse | |
| from faster_whisper import WhisperModel | |
| from starlette.staticfiles import StaticFiles | |
| from starlette.templating import Jinja2Templates | |
| from zhconv import convert | |
| from utils.data_utils import remove_punctuation | |
| from utils.utils import add_arguments, print_arguments | |
# NOTE(review): presumably a workaround for the OpenMP "already initialized"
# abort when two OpenMP runtimes end up in one process; confirm still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Command-line options, registered through the project helper ``add_arguments``
# bound to this parser. Each entry is (name, type, default); help text is empty.
# NOTE(review): plain argparse would turn any non-empty string into True for
# ``type=bool``; confirm ``add_arguments`` converts boolean strings itself.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
for _opt_name, _opt_type, _opt_default in (
        ("host", str, "0.0.0.0"),
        ("port", int, 5000),
        ("model_path", str, "models/sam2ai/whisper-odia-small-finetune-int8-ct2"),
        ("use_gpu", bool, False),
        ("use_int8", bool, True),
        ("beam_size", int, 10),
        ("num_workers", int, 2),
        ("vad_filter", bool, True),
        ("local_files_only", bool, True),
):
    add_arg(_opt_name, type=_opt_type, default=_opt_default, help="")
# Keep the loop variables out of the module namespace.
del _opt_name, _opt_type, _opt_default

args = parser.parse_args()
print_arguments(args)
# Fail fast if the model path is missing. An explicit raise (instead of
# ``assert``) keeps the check when Python runs with ``-O``.
if not os.path.exists(args.model_path):
    raise FileNotFoundError(f"Model path not found: {args.model_path}")

# Choose the device and compute type from the CLI flags:
#   GPU -> "float16", or "int8_float16" when --use_int8 is set;
#   CPU -> always "int8" (--use_int8 is not consulted on CPU).
if args.use_gpu:
    _device = "cuda"
    _compute_type = "int8_float16" if args.use_int8 else "float16"
else:
    _device, _compute_type = "cpu", "int8"
model = WhisperModel(args.model_path, device=_device, compute_type=_compute_type,
                     num_workers=args.num_workers, local_files_only=args.local_files_only)
# Keep the helper names out of the module namespace.
del _device, _compute_type
# Web application: static assets served under /static, Jinja2 templates for
# the HTML index page.
app = FastAPI(title="OdiaGenAI Whisper ASR")
app.mount(path='/static', app=StaticFiles(directory='static'), name='static')
templates = Jinja2Templates(directory="templates")

# Concurrency limiter used by the streaming endpoint. It starts as None and is
# created on first use, inside a running event loop.
model_semaphore: "asyncio.Semaphore | None" = None
def release_model_semaphore():
    """Return one slot to the module-level ``model_semaphore``.

    NOTE(review): ``asyncio.Semaphore`` is not thread-safe. Starlette runs a
    plain (non-async) background task in a worker thread, so when this
    function is registered that way the release happens off the event-loop
    thread; consider ``loop.call_soon_threadsafe`` or an ``async`` wrapper.
    """
    semaphore = model_semaphore
    semaphore.release()
def recognition(file: BytesIO, to_simple: int,
                remove_pun: int, language: str = "bn",
                task: str = "transcribe"
                ):
    """Transcribe ``file`` with the module-level ``model``, one frame per segment.

    Args:
        file: audio handed straight to ``model.transcribe``; both endpoints
            pass a ``BytesIO`` holding the uploaded bytes.
        to_simple: accepted for API compatibility only; no post-processing is
            applied for it, so the value is ignored.
        remove_pun: accepted for API compatibility only; no punctuation
            removal is applied, so the value is ignored.
        language: language code forwarded to ``model.transcribe``; ``None``
            is passed through unchanged.
        task: forwarded to ``model.transcribe`` (default ``"transcribe"``).

    Yields:
        bytes: ``json.dumps({"result": text, "start": s, "end": e})`` encoded
        to bytes, followed by a single NUL byte as a frame separator.
        ``start`` and ``end`` are rounded to 2 decimals.

    Decoding uses the ``--beam_size`` and ``--vad_filter`` command-line
    options from ``args``.
    """
    # Honour --beam_size instead of a hard-coded width (the option's default
    # is the same 10 the old literal used).
    segments, _info = model.transcribe(file, beam_size=args.beam_size, task=task,
                                       language=language, vad_filter=args.vad_filter)
    for segment in segments:
        frame = {"result": segment.text,
                 "start": round(segment.start, 2),
                 "end": round(segment.end, 2)}
        yield json.dumps(frame).encode() + b"\0"
async def api_recognition_stream(
        to_simple: int = Body(1, description="", embed=True),
        remove_pun: int = Body(0, description="", embed=True),
        language: str = Body("bn", description="", embed=True),
        task: str = Body("transcribe", description="", embed=True),
        audio: UploadFile = File(..., description="")
):
    """Transcribe an uploaded audio file and stream segments as they decode.

    Each streamed frame comes from ``recognition``: a JSON object with
    ``result``, ``start`` and ``end``, followed by a NUL byte. The literal
    string ``"None"`` for ``language`` is mapped to ``None``.

    Concurrency is capped by the module-level ``model_semaphore`` (5 slots,
    created lazily on the first request). The slot is held from before the
    upload is read until the stream ends, fails, or is discarded. It is
    returned exactly once, on the event-loop thread.

    NOTE(review): no route decorator (e.g. ``@app.post(...)``) is attached to
    this function here, so it is not registered with ``app``; confirm the
    intended route.
    """
    global model_semaphore
    if language == "None":
        language = None
    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(5)
    semaphore = model_semaphore
    loop = asyncio.get_running_loop()
    released = False

    def release_once():
        # Only ever runs on the event-loop thread (asyncio primitives are not
        # thread-safe); later calls are no-ops so the slot is returned once.
        nonlocal released
        if not released:
            released = True
            semaphore.release()

    def release_from_any_thread():
        # Starlette iterates sync generators and runs sync background tasks in
        # worker threads, so hop back onto the loop before releasing.
        if not loop.is_closed():
            loop.call_soon_threadsafe(release_once)

    await semaphore.acquire()
    try:
        contents = await audio.read()
    except BaseException:
        # Nothing will be streamed, so nothing else would return the slot.
        release_once()
        raise

    def stream():
        try:
            yield from recognition(
                file=BytesIO(contents), to_simple=to_simple,
                remove_pun=remove_pun, language=language,
                task=task
            )
        finally:
            # Runs on normal completion and also when transcription raises
            # mid-stream; Starlette only runs background tasks after a
            # successful send, so the background release alone could leak.
            release_from_any_thread()

    background_tasks = BackgroundTasks()
    # Second release path after the response completes; release_once makes a
    # duplicate call harmless.
    background_tasks.add_task(release_from_any_thread)
    return StreamingResponse(stream(), background=background_tasks)
async def api_recognition(
        to_simple: int = Body(1, description="", embed=True),
        remove_pun: int = Body(0, description="", embed=True),
        language: str = Body("bn", description="", embed=True),
        task: str = Body("transcribe", description="", embed=True),
        audio: UploadFile = File(..., description="")
):
    """Transcribe an uploaded audio file and return every segment at once.

    Returns ``{"results": [...], "code": 0}``. Each result is the dict that
    ``recognition`` serialises (``result``, ``start``, ``end``). The literal
    string ``"None"`` for ``language`` is mapped to ``None``.

    NOTE(review): unlike the streaming handler this path does not take a slot
    from ``model_semaphore``. No route decorator is attached here either, so
    confirm how it is registered with ``app``.
    """
    if language == "None":
        language = None
    contents = await audio.read()

    def transcribe_all():
        frames = recognition(
            file=BytesIO(contents), to_simple=to_simple,
            remove_pun=remove_pun, language=language,
            task=task
        )
        # Each frame is JSON bytes plus a trailing NUL separator; drop the NUL.
        return [json.loads(frame[:-1].decode("utf-8")) for frame in frames]

    # ``model.transcribe`` runs synchronously. Consuming it directly inside
    # this coroutine would block the event loop, and every other request, for
    # the whole transcription, so run it in the default thread-pool executor.
    results = await asyncio.get_running_loop().run_in_executor(None, transcribe_all)
    return {"results": results, "code": 0}
async def index(request: Request):
    """Render ``index.html`` from the configured template directory.

    NOTE(review): ``id`` in the context is Python's builtin ``id`` function
    (this module defines no other ``id``); confirm whether the template needs
    it. No route decorator is attached here, so confirm how this handler is
    registered with ``app``.
    """
    context = {"request": request, "id": id}
    return templates.TemplateResponse("index.html", context)
if __name__ == "__main__":
    # Serve the FastAPI application with uvicorn on the host/port taken from
    # the command line.
    uvicorn.run(
        app=app,
        host=args.host,
        port=args.port,
    )