Grinding committed on
Commit
a428e05
·
verified ·
1 Parent(s): 1e98f31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -21
app.py CHANGED
@@ -1,9 +1,10 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
  from transformers import pipeline
3
- import torch
4
- import librosa
5
  import tempfile
6
  import os
 
 
7
 
8
  app = FastAPI()
9
 
@@ -12,7 +13,7 @@ try:
12
  asr_pipeline = pipeline(
13
  "automatic-speech-recognition",
14
  model="distil-whisper/distil-large-v3",
15
- torch_dtype=torch.float32,
16
  device="cpu",
17
  )
18
  print("✅ ASR model loaded successfully")
@@ -24,31 +25,66 @@ except Exception as e:
24
  async def transcribe_audio(audio_file: UploadFile = File(...)):
25
  if not asr_pipeline:
26
  raise HTTPException(status_code=503, detail="ASR model is not available.")
27
-
28
  audio_bytes = await audio_file.read()
29
 
30
- # Save to a temporary file to avoid librosa/numba caching/locator issues
31
- suffix = os.path.splitext(audio_file.filename)[1] or ".wav"
32
- tmp_path = None
33
  try:
34
- with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
35
- tmp.write(audio_bytes)
36
- tmp.flush()
37
- tmp_path = tmp.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # Load and resample to 16 kHz mono from the temporary path
40
- speech, sr = librosa.load(tmp_path, sr=16000, mono=True)
 
 
41
 
42
- # Perform transcription with chunking (keeps memory bounded)
 
 
 
 
 
 
 
 
 
 
 
 
43
  result = asr_pipeline(speech, chunk_length_s=30, stride_length_s=5)
44
- return {"transcription": result.get("text", "")}
 
 
 
 
45
 
46
  except Exception as e:
 
47
  raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}")
48
  finally:
49
- # Cleanup temp file if it was created
50
- if tmp_path and os.path.exists(tmp_path):
51
- try:
52
- os.remove(tmp_path)
53
- except Exception:
54
- pass
 
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
  from transformers import pipeline
3
+ import numpy as np
 
4
  import tempfile
5
  import os
6
+ import subprocess
7
+ import soundfile as sf
8
 
9
  app = FastAPI()
10
 
 
13
  asr_pipeline = pipeline(
14
  "automatic-speech-recognition",
15
  model="distil-whisper/distil-large-v3",
16
+ torch_dtype=None, # let pipeline pick sensible dtype
17
  device="cpu",
18
  )
19
  print("✅ ASR model loaded successfully")
 
25
  async def transcribe_audio(audio_file: UploadFile = File(...)):
26
  if not asr_pipeline:
27
  raise HTTPException(status_code=503, detail="ASR model is not available.")
28
+
29
  audio_bytes = await audio_file.read()
30
 
31
+ tmp_in = None
32
+ tmp_wav = None
 
33
  try:
34
+ # 1) Save uploaded bytes to a temporary file (preserve extension if available)
35
+ suffix = os.path.splitext(audio_file.filename)[1] or ""
36
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tf:
37
+ tf.write(audio_bytes)
38
+ tf.flush()
39
+ tmp_in = tf.name
40
+
41
+ # 2) Use ffmpeg to convert to 16kHz mono WAV PCM (stable, avoids librosa/numba)
42
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tfwav:
43
+ tmp_wav = tfwav.name
44
+
45
+ ffmpeg_cmd = [
46
+ "ffmpeg",
47
+ "-y", # overwrite
48
+ "-i", tmp_in, # input file
49
+ "-ar", "16000", # sample rate 16k
50
+ "-ac", "1", # mono
51
+ "-f", "wav",
52
+ tmp_wav
53
+ ]
54
 
55
+ proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
56
+ if proc.returncode != 0:
57
+ # include ffmpeg stderr for debugging
58
+ raise RuntimeError(f"ffmpeg error: {proc.stderr.strip()}")
59
 
60
+ # 3) Read WAV with soundfile into float32 waveform
61
+ speech, sr = sf.read(tmp_wav, dtype="float32")
62
+ if sr != 16000:
63
+ # should not happen because ffmpeg forced 16k, but check anyway
64
+ raise RuntimeError(f"Unexpected sample rate {sr}")
65
+
66
+ # 4) Transcribe using the transformers ASR pipeline
67
+ # Provide waveform as a 1-D numpy array
68
+ if speech.ndim > 1:
69
+ # ensure mono
70
+ speech = np.mean(speech, axis=1)
71
+
72
+ # chunking options to keep memory bounded
73
  result = asr_pipeline(speech, chunk_length_s=30, stride_length_s=5)
74
+ text = result.get("text", "") if isinstance(result, dict) else (
75
+ result[0].get("text", "") if isinstance(result, list) and result else ""
76
+ )
77
+
78
+ return {"transcription": text}
79
 
80
  except Exception as e:
81
+ # Return a 400 with a helpful message
82
  raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}")
83
  finally:
84
+ # cleanup temp files
85
+ for path in (tmp_in, tmp_wav):
86
+ if path and os.path.exists(path):
87
+ try:
88
+ os.remove(path)
89
+ except Exception:
90
+ pass