# Hosting-page residue captured with the file (not part of the program):
#   Spaces: Sleeping / Sleeping
#   File size: 7,105 Bytes
# 6ad1b35 (presumably the revision id shown by the file viewer; its
# line-number gutter has been removed so the imports parse)
import io
import logging
import os
import subprocess
import tempfile
import threading

import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel
# ─────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────
# Root logging: INFO and above, rendered as "<timestamp> <LEVEL> <message>".
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-level logger used by the helpers and route handlers in this file.
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────
# App
# ─────────────────────────────────────────────
# The FastAPI application object; the route decorators in this file attach to it.
app = FastAPI(
    title="Klas English Transcription API",
)
# ─────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────
# Each setting can be overridden by an environment variable of the same name;
# the second argument is the value used when the variable is not set.
ASR_MODEL_SIZE = os.environ.get("ASR_MODEL_SIZE", "small.en")
DEVICE = os.environ.get("DEVICE", "cpu")  # HF Spaces is CPU by default
COMPUTE_TYPE = os.environ.get("COMPUTE_TYPE", "int8")  # int8 works well on CPU
# ─────────────────────────────────────────────
# Globals
# ─────────────────────────────────────────────
# Readiness flag: stays False until the loader has finished successfully.
model_loaded = False
# Shared model handle: None until the loader assigns the WhisperModel instance.
asr_model = None
# ─────────────────────────────────────────────
# Model loading — background thread so the
# container passes HF's health check quickly
# ─────────────────────────────────────────────
def _load_model():
    """Create the faster-whisper model and publish it through module globals.

    On success ``asr_model`` is set to the new ``WhisperModel`` and
    ``model_loaded`` becomes True. Any exception is logged with its traceback
    and swallowed, in which case ``model_loaded`` is left unchanged.
    """
    global asr_model, model_loaded
    whisper_options = {"device": DEVICE, "compute_type": COMPUTE_TYPE}
    try:
        logger.info(f"Loading faster-whisper ({ASR_MODEL_SIZE}) on {DEVICE}/{COMPUTE_TYPE}")
        asr_model = WhisperModel(ASR_MODEL_SIZE, **whisper_options)
        # Set the flag only after the model object exists, so readers that
        # see True can rely on asr_model being assigned.
        model_loaded = True
        logger.info("Model ready ✅")
    except Exception as e:
        logger.error(f"Model load failed: {e}", exc_info=True)
@app.on_event("startup")
async def startup():
    """Start model loading on a daemon thread so startup itself returns at once."""
    loader_thread = threading.Thread(target=_load_model, daemon=True)
    loader_thread.start()
# ─────────────────────────────────────────────
# Health
# ─────────────────────────────────────────────
@app.get("/ping")
def ping():
    """Return ``{"status": "ready"}`` once the model is loaded, else ``"initializing"``."""
    if model_loaded:
        status = "ready"
    else:
        status = "initializing"
    return {"status": status}
# ─────────────────────────────────────────────
# Audio helper
# ─────────────────────────────────────────────
def _resample_to_16k(samples: np.ndarray, source_rate: int) -> np.ndarray:
    """Resample mono *samples* from *source_rate* Hz to 16 kHz by invoking ffmpeg."""
    # A scratch directory managed by a context manager removes both files on
    # exit (success or failure) and keeps no open handle on either of them.
    with tempfile.TemporaryDirectory() as scratch:
        src_path = os.path.join(scratch, "in.wav")
        dst_path = os.path.join(scratch, "out.wav")
        sf.write(src_path, samples, source_rate)
        subprocess.run(
            ["ffmpeg", "-y", "-i", src_path,
             "-ar", "16000", "-ac", "1", dst_path],
            check=True, capture_output=True,
        )
        resampled, _ = sf.read(dst_path, dtype="float32")
    return resampled


def _load_audio(raw: bytes) -> np.ndarray:
    """Read audio bytes → float32 mono 16 kHz numpy array.

    The bytes are first decoded with soundfile. If that fails for any reason
    they are interpreted as headerless float32 PCM sampled at 16 kHz.
    Multi-channel audio is averaged down to one channel, audio at another
    sample rate is resampled through ffmpeg, and a signal whose peak
    magnitude exceeds 1.0 is scaled down so the peak becomes 1.0.

    Args:
        raw: Complete contents of the uploaded audio file.

    Returns:
        One-dimensional float32 array of samples at 16 kHz.

    Raises:
        ValueError: If the bytes are neither decodable nor a whole number of
            float32 values, or if the audio contains no samples.
        subprocess.CalledProcessError: If ffmpeg exits with an error while
            resampling.
    """
    try:
        arr, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
    except Exception as decode_error:
        # Fallback: treat as raw float32 PCM at 16 kHz
        if len(raw) % 4:
            # float32 is 4 bytes wide; anything else cannot be raw PCM.
            raise ValueError(
                "Unrecognised audio format and not raw float32 PCM"
            ) from decode_error
        arr = np.frombuffer(raw, dtype=np.float32)
        sr = 16000

    # Stereo → mono
    if arr.ndim > 1:
        arr = arr.mean(axis=1)

    # Reject empty audio up front: the peak computation below would otherwise
    # fail with an opaque "zero-size array" error.
    if arr.size == 0:
        raise ValueError("Audio contains no samples")

    # Resample to 16 kHz if needed
    if sr != 16000:
        logger.info("Resampling %s Hz → 16000 Hz", sr)
        arr = _resample_to_16k(arr, sr)

    # Normalize: only attenuate signals that exceed full scale, never amplify.
    peak = np.abs(arr).max()
    if peak > 1.0:
        arr = arr / peak
    return arr
# ─────────────────────────────────────────────
# Transcription helper
# ─────────────────────────────────────────────
def _transcribe(audio_arr: np.ndarray) -> str:
    """Transcribe an audio array to English text and tidy its casing.

    Runs the global ``asr_model`` with English forced, beam size 5 and VAD
    filtering enabled, then joins the segment texts into one string.

    Args:
        audio_arr: Audio samples to transcribe (the caller passes the array
            produced by ``_load_audio``).

    Returns:
        The transcript with its first character upper-cased, or an empty
        string when no segment contained any text.
    """
    segments, _ = asr_model.transcribe(
        audio_arr,
        language="en",
        beam_size=5,
        vad_filter=True,
        word_timestamps=False,
    )
    # Segment texts may carry their own leading/trailing whitespace; strip each
    # piece and skip blank ones so the join never produces doubled spaces.
    pieces = (seg.text.strip() for seg in segments)
    text = " ".join(piece for piece in pieces if piece)
    if not text:
        return text
    # Fix ALL-CAPS transcriptions (some audio conditions trigger this)
    if text == text.upper():
        text = text.lower()
    # text is non-empty here, so slicing covers the one-character case too.
    return text[0].upper() + text[1:]
# ─────────────────────────────────────────────
# POST /transcribe
# Accepts : multipart/form-data { audio: <file> }
# Returns : JSON { transcript, duration_sec, language }
# ─────────────────────────────────────────────
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    language: str = Form("en"),  # kept for future multi-lang support
):
    """Transcribe an uploaded audio file and return the text as JSON.

    Args:
        audio: The uploaded audio file (multipart form field ``audio``).
        language: Accepted for forward compatibility but not used; the
            response always reports ``"en"``.

    Returns:
        JSONResponse with ``transcript``, ``duration_sec`` (sample count
        divided by 16000, rounded to 3 decimals) and ``language``.

    Raises:
        HTTPException: 503 while the model is still loading, 400 for an empty
            upload, 422 when the audio cannot be read or yields no text, and
            500 when transcription itself fails.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model still loading, try again shortly")

    raw = await audio.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Decoding (which may shell out to ffmpeg) and inference are blocking
    # calls. Running them in the worker thread pool keeps the event loop free
    # to serve other requests, such as /ping, while a transcription is running.
    try:
        arr = await run_in_threadpool(_load_audio, raw)
    except Exception as e:
        logger.error(f"Audio load error: {e}", exc_info=True)
        raise HTTPException(status_code=422, detail=f"Could not read audio: {e}") from e

    duration_sec = round(len(arr) / 16000, 3)

    try:
        transcript = await run_in_threadpool(_transcribe, arr)
    except Exception as e:
        logger.error(f"Transcription error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not transcript:
        raise HTTPException(status_code=422, detail="No speech detected in audio")

    return JSONResponse({
        "transcript": transcript,
        "duration_sec": duration_sec,
        "language": "en",
    })
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn when this file is run
    # directly. uvicorn is imported here so importing the module does not
    # require it.
    import uvicorn

    # Listen on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)