Grinding committed on
Commit
736aad2
·
verified ·
1 Parent(s): fb70c55

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -58
app.py CHANGED
@@ -1,113 +1,90 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from faster_whisper import WhisperModel
3
  import numpy as np
4
  import tempfile
5
  import os
6
  import subprocess
7
  import soundfile as sf
8
- import logging
9
-
10
# Debug-friendly logging for the whole service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Load the ASR model once at import time so request handlers can reuse it.
try:
    # Directory where download_model.py saved the converted model.
    cache_dir = os.getenv("HF_HOME", "/data/hf_cache")
    logger.info(f"Using model directory: {cache_dir}")

    # The converted CTranslate2 model must contain model.bin at the top level.
    model_bin_path = os.path.join(cache_dir, "model.bin")
    if os.path.exists(model_bin_path):
        logger.info(f"model.bin found at {model_bin_path}")
        # Load from the local directory path (not a hub repo ID).
        asr_model = WhisperModel(
            cache_dir,
            device="cpu",
            compute_type="int8",
        )
        logger.info("✅ ASR model loaded successfully from local directory.")
    else:
        logger.error(f"model.bin not found at {model_bin_path}")
        # Dump the directory listing so a bad download is easy to diagnose.
        if os.path.exists(cache_dir):
            logger.info(f"Contents of {cache_dir}: {os.listdir(cache_dir)}")
        asr_model = None
except Exception as e:
    # Leave the model unset; the endpoint answers 503 when it is None.
    asr_model = None
    logger.error(f"❌ Error loading ASR model: {e}")
43
 
44
@app.post("/transcribe")
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Convert an uploaded audio file to 16 kHz mono WAV with ffmpeg and
    return the faster-whisper transcription as ``{"transcription": text}``.

    Responds 503 when the model failed to load and 400 on any processing
    error; temp files are always cleaned up.
    """
    if not asr_model:
        logger.error("ASR model is not available")
        raise HTTPException(status_code=503, detail="ASR model is not available.")

    audio_bytes = await audio_file.read()

    src_path = None
    wav_path = None
    try:
        # Persist the upload so ffmpeg can read it from disk, keeping the
        # original extension when the client supplied one.
        ext = os.path.splitext(audio_file.filename)[1] or ""
        with tempfile.NamedTemporaryFile(suffix=ext, delete=False, dir="/tmp") as src:
            src.write(audio_bytes)
            src.flush()
            src_path = src.name

        # Reserve an output path; ffmpeg overwrites it below.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False, dir="/tmp") as dst:
            wav_path = dst.name

        cmd = [
            "ffmpeg",
            "-y",
            "-i", src_path,
            "-ar", "16000",
            "-ac", "1",
            "-f", "wav",
            wav_path,
        ]

        proc = subprocess.run(cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            logger.error(f"ffmpeg error: {proc.stderr.strip()}")
            raise RuntimeError(f"ffmpeg error: {proc.stderr.strip()}")

        # Decode the converted WAV; sample rate is forced to 16 kHz above.
        speech, sr = sf.read(wav_path, dtype="float32")
        if sr != 16000:
            logger.error(f"Unexpected sample rate {sr}")
            raise RuntimeError(f"Unexpected sample rate {sr}")

        # Downmix to mono if soundfile returned multiple channels.
        if speech.ndim > 1:
            speech = np.mean(speech, axis=1)

        logger.info("Starting transcription")
        segments, _ = asr_model.transcribe(
            speech,
            beam_size=5,
            vad_filter=True,  # skip silence
            vad_parameters=dict(min_silence_duration_ms=500),
        )
        text = " ".join(segment.text.strip() for segment in segments)
        logger.info("Transcription completed")

        return {"transcription": text}

    except Exception as e:
        logger.error(f"Could not process audio file: {e}")
        raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}")
    finally:
        # Best-effort removal of both temp files.
        for path in (src_path, wav_path):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except Exception as e:
                    logger.warning(f"Failed to delete temp file {path}: {e}")
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from transformers import pipeline
3
  import numpy as np
4
  import tempfile
5
  import os
6
  import subprocess
7
  import soundfile as sf
 
 
 
 
 
8
 
9
app = FastAPI()

# Load the ASR pipeline once at startup; requests reuse the shared instance.
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="distil-whisper/distil-large-v3",
        torch_dtype=None,  # let pipeline pick sensible dtype
        device="cpu",
    )
    print("✅ ASR model loaded successfully")
except Exception as e:
    # Leave the pipeline unset; the endpoint answers 503 when it is None.
    asr_pipeline = None
    print(f"❌ Error loading ASR model: {e}")
23
 
24
@app.post("/transcribe")
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    The upload is written to a temp file, converted to 16 kHz mono WAV
    with ffmpeg, decoded with soundfile, and fed to the transformers ASR
    pipeline. Returns ``{"transcription": <text>}``.

    Raises:
        HTTPException 503: the ASR pipeline failed to load at startup.
        HTTPException 400: empty upload or any processing/conversion error.
    """
    if not asr_pipeline:
        raise HTTPException(status_code=503, detail="ASR model is not available.")

    audio_bytes = await audio_file.read()
    # Fail fast on an empty upload instead of surfacing a cryptic ffmpeg error.
    if not audio_bytes:
        raise HTTPException(status_code=400, detail="Uploaded audio file is empty.")

    tmp_in = None
    tmp_wav = None
    try:
        # 1) Save uploaded bytes to a temporary file (preserve extension if available)
        suffix = os.path.splitext(audio_file.filename)[1] or ""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tf:
            tf.write(audio_bytes)
            tf.flush()
            tmp_in = tf.name

        # 2) Use ffmpeg to convert to 16kHz mono WAV PCM (stable, avoids librosa/numba)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tfwav:
            tmp_wav = tfwav.name

        ffmpeg_cmd = [
            "ffmpeg",
            "-y",               # overwrite the reserved output path
            "-i", tmp_in,       # input file
            "-ar", "16000",     # sample rate 16k
            "-ac", "1",         # mono
            "-f", "wav",
            tmp_wav,
        ]

        proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
        if proc.returncode != 0:
            # include ffmpeg stderr for debugging
            raise RuntimeError(f"ffmpeg error: {proc.stderr.strip()}")

        # 3) Read WAV with soundfile into float32 waveform
        speech, sr = sf.read(tmp_wav, dtype="float32")
        if sr != 16000:
            # should not happen because ffmpeg forced 16k, but check anyway
            raise RuntimeError(f"Unexpected sample rate {sr}")

        # 4) Ensure a 1-D mono waveform for the pipeline
        if speech.ndim > 1:
            speech = np.mean(speech, axis=1)

        # chunking options to keep memory bounded
        result = asr_pipeline(speech, chunk_length_s=30, stride_length_s=5)
        # Unpack explicitly: a dict for single-output, a list for batched output.
        if isinstance(result, dict):
            text = result.get("text", "")
        elif isinstance(result, list) and result:
            text = result[0].get("text", "")
        else:
            text = ""

        return {"transcription": text}

    except Exception as e:
        # Chain the original exception so the full traceback survives in logs.
        raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}") from e
    finally:
        # Best-effort cleanup of temp files.
        for path in (tmp_in, tmp_wav):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except Exception:
                    pass