| from faster_whisper import WhisperModel |
| from fastapi import FastAPI, UploadFile, File, Form |
| import asyncio |
| import uvicorn |
| import io |
| from typing import List |
|
|
# FastAPI application object; every route and startup hook registers on it.
app = FastAPI()


# --- Tunables ---------------------------------------------------------------
# Upper bound on how many queued requests one worker claims per round.
MAX_BATCH_SIZE = 10
# Seconds a worker waits for each additional request before closing a round.
QUEUE_TIMEOUT = 0.1
# How many independent model instances (and hence worker tasks) to create.
MODEL_WORKERS = 2


def _load_whisper_model() -> WhisperModel:
    """Build one int8 "base" Whisper model on CPU, limited to one thread."""
    return WhisperModel("base", device="cpu", compute_type="int8", cpu_threads=1)


# One model per worker, so workers never share a model instance.
models = [_load_whisper_model() for _ in range(MODEL_WORKERS)]


# Hand-off point between the HTTP handler (producer) and workers (consumers).
# NOTE(review): before Python 3.10, asyncio.Queue binds to the event loop
# current at construction time. This line runs at import, before uvicorn
# starts its own loop. Confirm the deployment runs on Python 3.10+, or create
# the queue inside the startup hook.
request_queue: asyncio.Queue = asyncio.Queue()
|
|
class TranscriptionRequest:
    """One queued transcription job plus the future that delivers its result.

    Must be constructed while an event loop is running (the ``/transcribe``
    handler does this). The future is created on that running loop, so the
    worker that resolves it and the handler that awaits it use the same loop.

    Raises:
        RuntimeError: if constructed when no event loop is running.
    """

    def __init__(self, audio_bytes: bytes, lang: str):
        # Raw uploaded file contents; a worker hands them to the model.
        self.audio_bytes = audio_bytes
        # Language code forwarded to the model (the endpoint defaults to "en").
        self.lang = lang
        # get_running_loop() rather than get_event_loop(): it is the
        # recommended call inside coroutines, always returns the loop that is
        # actually running, and fails loudly instead of silently creating or
        # returning a different loop.
        self.future = asyncio.get_running_loop().create_future()
|
|
|
|
@app.get("/")
def home():
    """Root route: report that the batch server is running."""
    status_text = "Batch Whisper server running"
    return dict(message=status_text)
|
|
|
|
| |
def _transcribe_sync(model, audio_bytes, lang):
    """Run one blocking transcription with *model* and return the joined text.

    faster-whisper documents ``segments`` as a lazy generator: decoding
    happens while it is iterated. So both the call and the iteration live
    here, and the whole CPU-heavy job runs on whichever thread calls this.
    """
    segments, _info = model.transcribe(
        io.BytesIO(audio_bytes),
        language=lang,
        beam_size=1,
        best_of=1,
        vad_filter=True,
        condition_on_previous_text=False,
    )
    return "".join(seg.text for seg in segments)


async def _collect_batch() -> "List[TranscriptionRequest]":
    """Wait for one queued request, then gather more without blocking long.

    After the first request arrives, keep pulling from ``request_queue``
    until ``MAX_BATCH_SIZE`` requests are held or no further request arrives
    within ``QUEUE_TIMEOUT`` seconds.
    """
    batch: List[TranscriptionRequest] = [await request_queue.get()]
    try:
        while len(batch) < MAX_BATCH_SIZE:
            batch.append(
                await asyncio.wait_for(request_queue.get(), timeout=QUEUE_TIMEOUT)
            )
    except asyncio.TimeoutError:
        pass  # Queue went quiet; process what we have.
    return batch


async def batch_worker(model):
    """Consume ``request_queue`` forever, transcribing requests with *model*.

    Each transcription runs in the event loop's default executor (a thread),
    not on the event loop itself. The loop therefore stays responsive to new
    HTTP requests and to the other workers while Whisper is busy.

    Every request's future is resolved at most once: with the transcript on
    success, or with the raised exception on failure. Requests whose future
    is already done (e.g. cancelled because the client went away) are
    skipped. A single bad or abandoned request therefore cannot kill this
    worker or leave a caller hanging.

    Args:
        model: the WhisperModel this worker transcribes with (startup gives
            each worker its own instance).
    """
    loop = asyncio.get_running_loop()
    while True:
        batch = await _collect_batch()

        # NOTE(review): requests in a batch are transcribed one at a time
        # (there is no batched inference). MAX_BATCH_SIZE / QUEUE_TIMEOUT only
        # decide how many requests this worker claims at once. Claimed
        # requests wait here even if another worker is idle.
        for req in batch:
            if req.future.done():
                continue  # Caller already gave up; don't waste model time.
            try:
                text = await loop.run_in_executor(
                    None, _transcribe_sync, model, req.audio_bytes, req.lang
                )
            except Exception as exc:  # Deliver the failure to the caller.
                if not req.future.done():
                    req.future.set_exception(exc)
            else:
                # The caller may have been cancelled while we were transcribing.
                if not req.future.done():
                    req.future.set_result(text)
|
|
|
|
| |
@app.on_event("startup")
async def startup():
    """Launch one ``batch_worker`` task per loaded model.

    The event loop keeps only weak references to tasks, so the returned task
    objects are stored on ``app.state``. Without a strong reference, a worker
    could be garbage-collected while it sits waiting on the queue.
    """
    app.state.batch_worker_tasks = [
        asyncio.create_task(batch_worker(model)) for model in models
    ]
|
|
|
|
| |
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    lang: str = Form("en")
):
    """Queue the uploaded audio for a worker and return its transcript.

    The handler reads the whole upload and enqueues a TranscriptionRequest.
    It then waits on that request's future until a worker resolves it.
    """
    pending = TranscriptionRequest(await audio.read(), lang)
    await request_queue.put(pending)
    transcript = await pending.future
    return {"transcript": transcript}
|
|
|
|
def main():
    """Serve ``app`` with uvicorn on all interfaces at port 7860."""
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()