| from starlette.applications import Starlette |
| from starlette.responses import JSONResponse |
| from starlette.staticfiles import StaticFiles |
| from starlette.middleware.cors import CORSMiddleware |
| from starlette.requests import Request |
| from starlette.templating import Jinja2Templates |
| from starlette.routing import Route |
| from starlette.responses import RedirectResponse |
| import uvicorn |
|
|
| from transformers import pipeline |
| from pydub import AudioSegment |
| import torch |
| import asyncio |
|
|
# Pin inference to GPU index 2 when CUDA is available, else CPU.
# NOTE(review): hard-coded "cuda:2" assumes a multi-GPU host — confirm.
device = "cuda:2" if torch.cuda.is_available() else "cpu"


# Base ASGI application: serves /static assets and accepts cross-origin
# requests from any origin.
app = Starlette()
app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["X-Requested-With", "Content-Type"],
)


# Jinja2 templates rendered by the route handlers below.
templates = Jinja2Templates(directory="templates")
|
|
|
|
async def homepage(request):
    """Render the landing page template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
|
async def upload_file(request):
    """Handle an audio upload and render the page with its transcription.

    Reads the uploaded file from the multipart form, hands the raw bytes
    to the background inference loop via ``request.app.model_queue``,
    awaits the result on a per-request response queue, and re-renders
    ``index.html`` with the transcribed text.
    """
    formdata = await request.form()
    upload = formdata["file"]
    # Raw audio bytes; renamed from `input`, which shadowed the builtin.
    audio_bytes = await upload.read()
    # One-shot queue so the shared inference loop can answer this request.
    response_q = asyncio.Queue()
    await request.app.model_queue.put((audio_bytes, response_q))
    output = await response_q.get()
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "content": output["text"]},
    )
|
|
|
|
async def server_loop(q):
    """Background inference loop consuming (audio_bytes, response_queue) pairs.

    Loads the Whisper pipeline once, forces Chinese transcription, then
    serves requests from *q* forever, putting each pipeline result on the
    requester's response queue.
    """
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large",
        chunk_length_s=30,
        device=device,
    )

    # Force the decoder to transcribe in Chinese regardless of the
    # language Whisper would otherwise auto-detect.
    pipe.model.config.forced_decoder_ids = (
        pipe.tokenizer.get_decoder_prompt_ids(
            language="zh",
            task="transcribe",
        )
    )
    loop = asyncio.get_running_loop()
    while True:
        (audio_bytes, response_q) = await q.get()
        # The pipeline call is blocking and can take seconds; run it in a
        # worker thread so the event loop stays responsive instead of
        # freezing the whole server for every request.
        out = await loop.run_in_executor(None, pipe, audio_bytes)
        await response_q.put(out)
|
|
|
|
# Register routes on the ALREADY-configured app. The previous code
# re-assigned `app = Starlette(routes=[...])`, which silently discarded
# the /static mount and the CORS middleware added at module setup.
app.add_route("/", homepage, methods=["GET"])
app.add_route("/upload", upload_file, methods=["POST"])
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Create the shared model queue and launch the background inference loop."""
    q = asyncio.Queue()
    app.model_queue = q
    # Keep a reference to the task: the event loop holds only a weak
    # reference, so an un-referenced task may be garbage-collected
    # before (or while) it runs.
    app.model_task = asyncio.create_task(server_loop(q))
|
|