wh / app.py
don0726's picture
Update app.py
1f3164c verified
from faster_whisper import WhisperModel
from fastapi import FastAPI, UploadFile, File, Form
import asyncio
import uvicorn
import io
from typing import List
# FastAPI application object; the route and startup decorators in this file
# register their handlers on it.
app = FastAPI()
# ===== CONFIGURATION =====
MAX_BATCH_SIZE = 10  # upper bound on requests a worker pulls off the queue per batch
QUEUE_TIMEOUT = 0.1  # seconds each follow-up queue read may wait while a batch is being filled
MODEL_WORKERS = 2  # number of model instances (one worker task per instance); sized for a 2 vCPU host
# ===== LOAD MODELS =====
# One independent "base" Whisper instance per worker, each limited to a single
# CPU thread and using int8 computation on the CPU.
models = [
    WhisperModel("base", cpu_threads=1, compute_type="int8", device="cpu")
    for _worker_index in range(MODEL_WORKERS)
]
# ===== REQUEST QUEUE =====
# Shared queue of pending TranscriptionRequest objects: the /transcribe handler
# puts requests on it and the batch workers take them off.
# NOTE(review): the queue is constructed at import time, before the server's
# event loop is running. Older Python versions (before 3.10, presumably) bound
# asyncio.Queue to the loop that was current at construction -- confirm the
# deployment's Python version, or create the queue during startup instead.
request_queue: asyncio.Queue = asyncio.Queue()
class TranscriptionRequest:
    """One queued transcription job and the future its result is delivered on.

    Instances must be created while an event loop is running (for example
    inside a request handler coroutine), because the result future is bound
    to the running loop.

    Attributes:
        audio_bytes: Raw bytes of the uploaded audio file.
        lang: Language code handed to the model for this request.
        future: Future resolved by a batch worker with the transcript text.

    Raises:
        RuntimeError: If constructed when no event loop is running.
    """

    def __init__(self, audio_bytes: bytes, lang: str):
        self.audio_bytes = audio_bytes
        self.lang = lang
        # get_running_loop() guarantees the future belongs to the loop that is
        # actually serving the request; get_event_loop() is deprecated for
        # this purpose and could silently pick or create a different loop.
        self.future = asyncio.get_running_loop().create_future()
@app.get("/")
def home():
    """Return a fixed status payload showing that the server is reachable."""
    status = {"message": "Batch Whisper server running"}
    return status
# ===== BATCH WORKER =====
def _transcribe_sync(model, audio_bytes, lang):
    """Run one blocking transcription and return the concatenated segment text."""
    segments, _info = model.transcribe(
        io.BytesIO(audio_bytes),
        language=lang,
        beam_size=1,
        best_of=1,
        vad_filter=True,
        condition_on_previous_text=False,
    )
    # The segments iterable is consumed here so that any work it performs
    # lazily also happens in the calling (worker) thread.
    return "".join(seg.text for seg in segments)


async def batch_worker(model):
    """Pull requests off the shared queue forever and resolve their futures.

    The worker waits for one request, then keeps reading the queue until
    either MAX_BATCH_SIZE requests are collected or a single read exceeds
    QUEUE_TIMEOUT seconds. Each collected request is transcribed with
    ``model`` and its future receives the transcript text.

    Transcription runs in the loop's default executor so the event loop stays
    responsive while the model is busy. A failure in one request is delivered
    to that request's future as an exception instead of terminating the
    worker, so later requests are still served.

    Args:
        model: Object exposing a ``transcribe(audio, language=..., ...)``
            method returning ``(segments, info)``, where each segment has a
            ``text`` attribute.
    """
    loop = asyncio.get_running_loop()
    while True:
        # Block until there is at least one request to work on.
        batch: List[TranscriptionRequest] = [await request_queue.get()]
        # Collect more requests for a short time.
        try:
            while len(batch) < MAX_BATCH_SIZE:
                batch.append(
                    await asyncio.wait_for(
                        request_queue.get(),
                        timeout=QUEUE_TIMEOUT,
                    )
                )
        except asyncio.TimeoutError:
            pass  # no further request arrived in time; process what we have
        # Process batch
        for req in batch:
            if req.future.done():
                # The waiter already gave up (e.g. was cancelled); skip the work.
                continue
            try:
                # Requests are awaited one at a time, so this worker never
                # uses its model from two threads at once.
                text = await loop.run_in_executor(
                    None, _transcribe_sync, model, req.audio_bytes, req.lang
                )
            except Exception as exc:
                # Hand the failure to the waiting caller; the worker keeps running.
                if not req.future.done():
                    req.future.set_exception(exc)
            else:
                # The future may have been cancelled while the model was running.
                if not req.future.done():
                    req.future.set_result(text)
# ===== START WORKERS =====
# Strong references to the running worker tasks. The event loop itself only
# holds weak references to tasks, so a task nobody else references may be
# garbage-collected while it is still running.
_worker_tasks: List[asyncio.Task] = []


@app.on_event("startup")
async def startup():
    """Launch one background batch worker per loaded model at server startup."""
    for model in models:
        _worker_tasks.append(asyncio.create_task(batch_worker(model)))
# ===== API ENDPOINT =====
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    lang: str = Form("en"),
):
    """Queue an uploaded audio file for transcription and return its text.

    Args:
        audio: Uploaded audio file (multipart form field ``audio``).
        lang: Language code passed to the model; defaults to ``"en"``.

    Returns:
        A dict of the form ``{"transcript": <text>}``.

    Raises:
        HTTPException: 400 when the uploaded file contains no data.
    """
    # Imported locally so this handler does not depend on edits to the
    # file-level import list.
    from fastapi import HTTPException

    audio_bytes = await audio.read()
    if not audio_bytes:
        # Nothing to transcribe: reject immediately instead of occupying a
        # queue slot and a worker with a request that cannot succeed.
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty")
    req = TranscriptionRequest(audio_bytes, lang)
    await request_queue.put(req)
    # Suspend until a batch worker resolves this request's future.
    text = await req.future
    return {"transcript": text}
if __name__ == "__main__":
    # Started as a script: serve the app with uvicorn on all interfaces, port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")