Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException | |
| from transformers import pipeline | |
| import numpy as np | |
| import tempfile | |
| import os | |
| import subprocess | |
| import soundfile as sf | |
# FastAPI application instance for the transcription service.
app = FastAPI()

# Eagerly load the ASR pipeline at import time so the first request does
# not pay the model-load cost.  If loading fails, keep the app running and
# let the endpoint answer 503 instead of crashing on startup.
asr_pipeline = None
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="distil-whisper/distil-large-v3",
        torch_dtype=None,  # let pipeline pick sensible dtype
        device="cpu",
    )
    print("✅ ASR model loaded successfully")
except Exception as e:
    print(f"❌ Error loading ASR model: {e}")
@app.post("/transcribe")  # NOTE(review): SOURCE had no route decorator, leaving the endpoint unregistered — confirm path
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file to text.

    The upload is written to a temp file, converted with ffmpeg to a
    16 kHz mono PCM WAV, decoded with soundfile into a float32 waveform,
    and fed to the transformers ASR pipeline.

    Args:
        audio_file: The uploaded audio, in any container/codec ffmpeg
            can decode.

    Returns:
        dict: ``{"transcription": <recognized text>}``.

    Raises:
        HTTPException: 503 if the ASR model failed to load at startup;
            400 if the upload is empty or cannot be converted/decoded.
    """
    if not asr_pipeline:
        raise HTTPException(status_code=503, detail="ASR model is not available.")

    audio_bytes = await audio_file.read()
    if not audio_bytes:
        # fail fast instead of handing ffmpeg a zero-byte file
        raise HTTPException(status_code=400, detail="Could not process audio file: empty upload")

    tmp_in = None
    tmp_wav = None
    try:
        # 1) Save uploaded bytes to a temporary file, preserving the
        #    extension so ffmpeg can sniff the container format.
        #    filename may be None for some clients, hence the `or ""`.
        suffix = os.path.splitext(audio_file.filename or "")[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tf:
            tf.write(audio_bytes)
            tmp_in = tf.name

        # 2) Convert to 16 kHz mono WAV PCM via ffmpeg (stable, avoids
        #    a librosa/numba decode dependency).  The temp WAV is only a
        #    reserved path; ffmpeg overwrites it thanks to -y.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tfwav:
            tmp_wav = tfwav.name
        ffmpeg_cmd = [
            "ffmpeg",
            "-y",               # overwrite the placeholder temp file
            "-i", tmp_in,       # input file
            "-ar", "16000",     # 16 kHz — what Whisper-family models expect
            "-ac", "1",         # mono
            "-f", "wav",
            tmp_wav,
        ]
        proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            # surface ffmpeg stderr so the 400 detail is actionable
            raise RuntimeError(f"ffmpeg error: {proc.stderr.strip()}")

        # 3) Decode the WAV into a float32 waveform.
        speech, sr = sf.read(tmp_wav, dtype="float32")
        if sr != 16000:
            # should not happen because ffmpeg forced 16k, but check anyway
            raise RuntimeError(f"Unexpected sample rate {sr}")
        if speech.ndim > 1:
            # defensive: collapse to mono even though ffmpeg forced 1 channel
            speech = np.mean(speech, axis=1)

        # 4) Transcribe; chunk/stride options keep memory bounded on long audio.
        result = asr_pipeline(speech, chunk_length_s=30, stride_length_s=5)
        # pipeline normally returns a dict; tolerate a list-of-dicts too
        text = result.get("text", "") if isinstance(result, dict) else (
            result[0].get("text", "") if isinstance(result, list) and result else ""
        )
        return {"transcription": text}
    except Exception as e:
        # Return a 400 with a helpful message, preserving the cause chain
        raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}") from e
    finally:
        # best-effort cleanup of both temp files
        for path in (tmp_in, tmp_wav):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:
                    pass