# asr_demo / upload.py
# kcltw's picture
# Upload folder using huggingface_hub
# 53d9e1b
import asyncio
import logging

import torch
import uvicorn
from pydub import AudioSegment
from starlette.applications import Starlette
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.responses import RedirectResponse
from starlette.routing import Route
from starlette.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
from transformers import pipeline
# Inference device for the ASR pipeline. GPU index 2 is preferred, but it only
# exists on hosts with at least three visible GPUs; on smaller hosts fall back
# to GPU 0 instead of handing the pipeline an invalid device ordinal.
_PREFERRED_GPU_INDEX = 2
if torch.cuda.is_available():
    _gpu_index = (
        _PREFERRED_GPU_INDEX
        if torch.cuda.device_count() > _PREFERRED_GPU_INDEX
        else 0
    )
    device = f"cuda:{_gpu_index}"
else:
    device = "cpu"
# Initial application object: serves ./static under /static and enables CORS.
# NOTE: `app` is assigned again further down in this file (with the route
# table), which replaces the instance configured here.
app = Starlette()

_static_files = StaticFiles(directory="static")
app.mount("/static", _static_files, name="static")

# CORS: any origin may call the API; only these request headers are allowed.
_cors_options = {
    "allow_origins": ["*"],
    "allow_headers": ["X-Requested-With", "Content-Type"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# Jinja2 environment rooted at ./templates, used by the page handlers.
templates = Jinja2Templates(directory="templates")
async def homepage(request):
    """Render ``index.html`` for the landing page.

    Args:
        request: The incoming request; passed to the template context under
            the ``"request"`` key.

    Returns:
        The template response for ``index.html``.
    """
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
async def upload_file(request):
    """Transcribe an uploaded audio file and render the result.

    Reads the ``file`` field of the submitted form, hands its bytes to the
    background worker through ``request.app.model_queue`` together with a
    private reply queue, then waits for the worker's answer and renders its
    ``'text'`` entry into ``index.html``.

    Args:
        request: The incoming POST request carrying the form data.

    Returns:
        The ``index.html`` template response. On success ``content`` holds the
        transcription; if no usable file was submitted the page is rendered
        with an explanatory message and HTTP status 400.
    """
    formdata = await request.form()
    upload = formdata.get("file")

    # A missing field yields None and a plain text field yields str; neither
    # can be read as an uploaded file.
    audio_bytes = b""
    if upload is not None and not isinstance(upload, str):
        audio_bytes = await upload.read()

    # Reject missing or empty uploads here rather than queueing data the
    # model cannot decode.
    if not audio_bytes:
        return templates.TemplateResponse(
            "index.html",
            {"request": request, "content": "No audio file was uploaded."},
            status_code=400,
        )

    # One reply queue per request so the worker can answer this caller only.
    response_q = asyncio.Queue()
    await request.app.model_queue.put((audio_bytes, response_q))
    output = await response_q.get()
    return templates.TemplateResponse(
        "index.html",
        {"request": request, "content": output['text']},
    )
def _load_asr_pipeline():
    """Build the Whisper speech-recognition pipeline pinned to Chinese transcription."""
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-large",
        chunk_length_s=30,
        device=device,
    )
    # Force the decoder prompt so the model transcribes in Chinese ("zh").
    pipe.model.config.forced_decoder_ids = (
        pipe.tokenizer.get_decoder_prompt_ids(
            language="zh",
            task="transcribe"
        )
    )
    return pipe


async def server_loop(q):
    """Background worker: transcribe queued audio one request at a time.

    Each queue item is a ``(audio, response_q)`` tuple. The pipeline output is
    put on ``response_q`` for the waiting request handler.

    Model loading and inference are blocking calls, so both run in the default
    executor; the event loop stays free to accept requests in the meantime.

    If inference raises, the error is logged and a reply with empty ``'text'``
    and an ``'error'`` message is sent, so the caller is released and the
    worker keeps serving later requests.

    Args:
        q: ``asyncio.Queue`` of ``(audio, response_q)`` tuples.
    """
    loop = asyncio.get_running_loop()
    pipe = await loop.run_in_executor(None, _load_asr_pipeline)
    while True:
        (audio, response_q) = await q.get()
        try:
            out = await loop.run_in_executor(None, pipe, audio)
        except asyncio.CancelledError:
            # Let task cancellation (e.g. on shutdown) propagate untouched.
            raise
        except Exception as exc:
            # Worker boundary: one bad upload must not stop the loop or leave
            # its caller waiting forever on the reply queue.
            logging.getLogger(__name__).exception("Transcription failed")
            out = {"text": "", "error": str(exc)}
        await response_q.put(out)
# The application instance that is actually served: page routes plus upload.
app = Starlette(
    routes=[
        Route("/", homepage, methods=["GET"]),
        Route("/upload", upload_file, methods=["POST"]),
    ],
)
# This assignment rebinds `app`, so any mount or middleware attached to the
# previous instance is gone. Attach static files and CORS to this instance so
# /static is served and the CORS policy takes effect.
app.mount("/static", StaticFiles(directory="static"), name="static")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["X-Requested-With", "Content-Type"],
)
@app.on_event("startup")
async def startup_event():
    """Create the shared request queue and start the background inference worker.

    The queue is stored on the app as ``model_queue`` so request handlers can
    submit work; the worker task is stored as ``model_task``.
    """
    q = asyncio.Queue()
    app.model_queue = q
    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected mid-run.
    app.model_task = asyncio.create_task(server_loop(q))