| | import os |
| | import pickle |
| | from typing import Union |
| | from fastapi import Request |
| | import torch |
| | from transformers import pipeline |
| |
|
| | from fastapi import FastAPI |
| | from contextlib import asynccontextmanager |
| |
|
| | DEVICE = os.getenv('DEVICE', 'mps') |
| | ATTN_IMPLEMENTATION = os.getenv('ATTN_IMPLEMENTATION', "sdpa") |
| |
|
| | transcribe_pipeline = None |
| |
|
| |
|
| | @asynccontextmanager |
| | async def lifespan(app: FastAPI): |
| | transcribe_pipeline = pipeline( |
| | "automatic-speech-recognition", |
| | model="openai/whisper-large-v3", |
| | torch_dtype=torch.float16 if ATTN_IMPLEMENTATION == "sdpa" else torch.bfloat16, |
| | device=DEVICE, |
| | model_kwargs={"attn_implementation": ATTN_IMPLEMENTATION}, |
| | ) |
| | transcribe_pipeline.model.to('cuda') |
| | yield |
| |
|
| | app = FastAPI(lifespan=lifespan) |
| |
|
| |
|
| |
|
| | @app.get("/") |
| | def read_root(): |
| | return {"status": "ok"} |
| |
|
| |
|
| |
|
| | @app.post("/transcribe") |
| | async def transcribe(request: Request): |
| | body = await request.body() |
| | audio_chunk = pickle.loads(body) |
| | outputs = transcribe_pipeline( |
| | audio_chunk, |
| | chunk_length_s=30, |
| | batch_size=24, |
| | generate_kwargs={ |
| | 'task': 'transcribe', |
| | 'language': 'english' |
| | }, |
| | return_timestamps='word' |
| | ) |
| | text = outputs["text"].strip() |
| | return {"transcribe": text} |
| |
|