File size: 7,105 Bytes
6ad1b35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os
import io
import logging
import tempfile
import threading
import subprocess

import numpy as np
import soundfile as sf
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from fastapi.responses import JSONResponse
from faster_whisper import WhisperModel

# ─────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────
# Root logging: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────
# App
# ─────────────────────────────────────────────
# ASGI application object; the route decorators below register onto it.
app = FastAPI(
    title="Klas English Transcription API",
)

# ─────────────────────────────────────────────
# Config
# ─────────────────────────────────────────────
# Each setting can be overridden by an environment variable of the same name;
# the second argument is the value used when the variable is unset.
ASR_MODEL_SIZE = os.environ.get("ASR_MODEL_SIZE", "small.en")
DEVICE = os.environ.get("DEVICE", "cpu")  # HF Spaces is CPU by default
COMPUTE_TYPE = os.environ.get("COMPUTE_TYPE", "int8")  # int8 works well on CPU

# ─────────────────────────────────────────────
# Globals
# ─────────────────────────────────────────────
# Shared state: written by the model loader, read by the request handlers.
model_loaded = False  # becomes True once the model has been constructed
asr_model = None  # replaced by the WhisperModel instance after loading


# ─────────────────────────────────────────────
# Model loading — background thread so the
# container passes HF's health check quickly
# ─────────────────────────────────────────────
def _load_model():
    """Build the Whisper model and publish it through the module globals.

    On success ``asr_model`` holds the model and ``model_loaded`` is set to
    True. Any exception is logged with its traceback and swallowed, so
    ``model_loaded`` stays False when construction fails.
    """
    global asr_model, model_loaded
    try:
        logger.info(f"Loading faster-whisper ({ASR_MODEL_SIZE}) on {DEVICE}/{COMPUTE_TYPE}")
        whisper_options = {"device": DEVICE, "compute_type": COMPUTE_TYPE}
        asr_model = WhisperModel(ASR_MODEL_SIZE, **whisper_options)
        model_loaded = True
        logger.info("Model ready ✅")
    except Exception as exc:
        logger.error(f"Model load failed: {exc}", exc_info=True)


@app.on_event("startup")
async def startup():
    """Start model loading on a daemon thread so startup returns immediately."""
    loader = threading.Thread(target=_load_model, daemon=True)
    loader.start()


# ─────────────────────────────────────────────
# Health
# ─────────────────────────────────────────────
@app.get("/ping")
def ping():
    """Report "ready" once the model is loaded, "initializing" before that."""
    if model_loaded:
        return {"status": "ready"}
    return {"status": "initializing"}


# ─────────────────────────────────────────────
# Audio helper
# ─────────────────────────────────────────────
def _load_audio(raw: bytes) -> np.ndarray:
    """Decode audio bytes into a float32 mono 16 kHz numpy array.

    The bytes are first handed to ``soundfile``. If that fails they are
    interpreted as headerless float32 PCM sampled at 16 kHz. Multi-channel
    audio is averaged down to mono, other sample rates are converted by an
    ``ffmpeg`` subprocess, and the signal is scaled down only when its peak
    exceeds 1.0.

    Args:
        raw: Encoded audio file contents, or raw float32 PCM samples.

    Returns:
        One-dimensional float32 array at 16 kHz with peak magnitude <= 1.0.

    Raises:
        ValueError: If the bytes are neither decodable nor a whole number of
            float32 samples, or if the audio is empty or contains NaN/Inf.
        subprocess.CalledProcessError: If ffmpeg exits non-zero while
            resampling.
    """
    target_sr = 16000

    try:
        arr, sr = sf.read(io.BytesIO(raw), dtype="float32", always_2d=False)
    except Exception:
        # Fallback: treat as raw float32 PCM at 16 kHz. Reject buffers that
        # cannot be a whole number of samples with a readable message instead
        # of numpy's low-level buffer-size error.
        if len(raw) % np.dtype(np.float32).itemsize:
            raise ValueError(
                "Unrecognized audio format (not decodable and not raw float32 PCM)"
            )
        arr = np.frombuffer(raw, dtype=np.float32)
        sr = target_sr

    # Stereo → mono
    if arr.ndim > 1:
        arr = arr.mean(axis=1)

    # Validate before doing any further work: an empty signal would otherwise
    # fail inside a numpy reduction, and NaN/Inf would slip through the peak
    # normalization below and reach the model.
    if arr.size == 0:
        raise ValueError("Audio contains no samples")
    if not np.isfinite(arr).all():
        raise ValueError("Audio contains non-finite samples (NaN/Inf)")

    # Resample to 16 kHz if needed
    if sr != target_sr:
        logger.info(f"Resampling {sr} Hz → 16000 Hz")
        # A temporary directory removes both files on every exit path and
        # leaves no open file handles behind.
        with tempfile.TemporaryDirectory() as tmp_dir:
            src_path = os.path.join(tmp_dir, "in.wav")
            dst_path = os.path.join(tmp_dir, "out.wav")
            sf.write(src_path, arr, sr)
            subprocess.run(
                ["ffmpeg", "-y", "-i", src_path,
                 "-ar", "16000", "-ac", "1", dst_path],
                check=True, capture_output=True,
            )
            arr, _ = sf.read(dst_path, dtype="float32")
        if arr.size == 0:
            raise ValueError("Audio contains no samples")

    # Normalize only when the signal would clip; quieter audio is untouched.
    peak = float(np.abs(arr).max())
    if peak > 1.0:
        arr = arr / peak

    return arr


# ─────────────────────────────────────────────
# Transcription helper
# ─────────────────────────────────────────────
def _transcribe(audio_arr: np.ndarray) -> str:
    """Run the loaded model on ``audio_arr`` and return a cleaned transcript.

    Args:
        audio_arr: Audio samples as produced by ``_load_audio``.

    Returns:
        The segment texts joined by single spaces with the first character
        upper-cased, or an empty string when no text was produced.
    """
    segments, _info = asr_model.transcribe(
        audio_arr,
        language="en",
        beam_size=5,
        vad_filter=True,
        word_timestamps=False,
    )
    # Segment texts typically carry their own leading space, so strip each
    # piece and drop empty ones; otherwise the join yields doubled spaces.
    pieces = (seg.text.strip() for seg in segments)
    text = " ".join(piece for piece in pieces if piece)
    if not text:
        return text
    # Fix ALL-CAPS transcriptions (some audio conditions trigger this)
    if text == text.upper():
        text = text.lower()
    # Slicing handles the single-character case too (text[1:] is then empty).
    return text[0].upper() + text[1:]


# ─────────────────────────────────────────────
# POST /transcribe
# Accepts : multipart/form-data  { audio: <file>, language: <str, optional, currently ignored> }
# Returns : JSON { transcript, duration_sec, language }
# ─────────────────────────────────────────────
@app.post("/transcribe")
async def transcribe(
    audio: UploadFile = File(...),
    language: str = Form("en"),   # kept for future multi-lang support
):
    """Transcribe an uploaded audio file.

    Args:
        audio: The uploaded audio file (multipart form field ``audio``).
        language: Accepted but not used yet; the response always reports "en".

    Returns:
        JSON with ``transcript``, ``duration_sec`` and ``language``.

    Raises:
        HTTPException: 503 while the model is not loaded, 400 for an empty
            upload, 422 when the audio cannot be read or no speech is found,
            500 when transcription itself fails.
    """
    # Local import keeps this handler self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model still loading, try again shortly")

    raw = await audio.read()
    if not raw:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")

    # Decoding (possibly an ffmpeg subprocess) and inference are blocking
    # calls. Running them on a worker thread keeps the event loop free, so
    # other requests such as /ping are still answered during a transcription.
    try:
        arr = await run_in_threadpool(_load_audio, raw)
    except Exception as e:
        logger.error(f"Audio load error: {e}", exc_info=True)
        raise HTTPException(status_code=422, detail=f"Could not read audio: {e}") from e

    # _load_audio always returns 16 kHz samples, so length / 16000 is seconds.
    duration_sec = round(len(arr) / 16000, 3)

    try:
        transcript = await run_in_threadpool(_transcribe, arr)
    except Exception as e:
        logger.error(f"Transcription error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not transcript:
        raise HTTPException(status_code=422, detail="No speech detected in audio")

    return JSONResponse({
        "transcript":   transcript,
        "duration_sec": duration_sec,
        "language":     "en",
    })


if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces at port 7860.
    import uvicorn

    bind = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **bind)