Grinding committed on
Commit
6b7471c
·
verified ·
1 Parent(s): 6755aad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -65
app.py CHANGED
@@ -1,102 +1,90 @@
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from typing import Optional
3
  import numpy as np
4
  import tempfile
5
  import os
6
  import subprocess
7
  import soundfile as sf
8
 
9
- from transformers import WhisperProcessor
10
- from optimum.onnxruntime import ORTModelForSeq2SeqLM
11
-
12
  app = FastAPI()
13
 
14
- # --- Model/Processor Repos ---
15
- # ONNX (quantized) model repo
16
- ONNX_REPO = "distil-whisper/distil-large-v3.5-ONNX"
17
- # Processor (feature extractor + tokenizer) repo
18
- PROCESSOR_REPO = "distil-whisper/distil-large-v3"
19
-
20
# --- Load ONNX model + processor (CPU) ---
try:
    # FIX: Whisper is an encoder-decoder *speech* model; the matching
    # optimum.onnxruntime class is ORTModelForSpeechSeq2Seq. The original
    # ORTModelForSeq2SeqLM targets text-to-text models and cannot load a
    # Whisper checkpoint. Imported locally so this block stands alone.
    from optimum.onnxruntime import ORTModelForSpeechSeq2Seq

    processor = WhisperProcessor.from_pretrained(PROCESSOR_REPO)
    model = ORTModelForSpeechSeq2Seq.from_pretrained(ONNX_REPO)
    print("✅ ONNX model & processor loaded")
except Exception as e:
    # Keep the app importable even when the model cannot be fetched;
    # the /transcribe endpoint answers 503 while these are None.
    processor = None
    model = None
    print(f"❌ Error loading ONNX model/processor: {e}")
30
def convert_to_16k_mono_wav(in_path: str) -> str:
    """Transcode *in_path* into a fresh temporary 16 kHz mono WAV file.

    Returns the path of the new file; the caller is responsible for
    deleting it. Raises RuntimeError (with ffmpeg's stderr) on failure.
    """
    # Create the target file up front so ffmpeg has a concrete path to
    # overwrite; close the handle immediately — only the name is needed.
    handle = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    handle.close()
    target = handle.name

    command = [
        "ffmpeg", "-y",
        "-i", in_path,
        "-ar", "16000",
        "-ac", "1",
        "-f", "wav",
        target,
    ]
    completed = subprocess.run(command, capture_output=True, text=True)
    if completed.returncode != 0:
        raise RuntimeError(f"ffmpeg error: {completed.stderr.strip()}")
    return target
47
 
48
@app.post("/transcribe")
async def transcribe_audio(
    audio_file: UploadFile = File(...),
    language: Optional[str] = None,  # e.g., "en"
    max_new_tokens: int = 128,
):
    """Transcribe an uploaded audio file with the ONNX Whisper model.

    Flow: save upload to disk -> ffmpeg-normalize to 16 kHz mono WAV ->
    read with soundfile -> generate with the ORT model.

    Returns {"transcription": <text>}.
    Raises HTTP 503 when the model failed to load at startup, and
    HTTP 400 for any per-request processing failure.
    """
    if model is None or processor is None:
        raise HTTPException(status_code=503, detail="ONNX ASR model is not available.")

    tmp_in = None
    tmp_wav = None
    try:
        # Save upload to a temp file, preserving the extension so ffmpeg
        # can use it as a container hint.
        suffix = os.path.splitext(audio_file.filename)[1] or ""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tf:
            tf.write(await audio_file.read())
            tf.flush()
            tmp_in = tf.name

        # Normalize to 16kHz mono WAV
        tmp_wav = convert_to_16k_mono_wav(tmp_in)

        # Read waveform as float32; ffmpeg forced 16 kHz, so treat any
        # other rate as an internal error.
        speech, sr = sf.read(tmp_wav, dtype="float32")
        if sr != 16000:
            raise RuntimeError(f"Unexpected sample rate {sr}")
        if speech.ndim > 1:
            speech = np.mean(speech, axis=1)

        # FIX: the original one-liner
        #   language or processor.tokenizer.language if hasattr(...) else None
        # parsed as `(language or ...) if hasattr(...) else None`, which
        # silently dropped an explicit `language` whenever the processor
        # had no `tokenizer` attribute. Resolve the fallback explicitly.
        forced_lang = language
        if forced_lang is None and hasattr(processor, "tokenizer"):
            # NOTE(review): assumes the tokenizer may expose a `.language`
            # attribute — use getattr so its absence is not an AttributeError.
            forced_lang = getattr(processor.tokenizer, "language", None)

        inputs = processor(speech, sampling_rate=16000, return_tensors="np")

        # Generate (ORT optimized); greedy decoding, bounded output length.
        gen_kwargs = dict(
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )
        # If a specific language is requested/known, force it via the
        # tokenizer's prefix tokens (when the API supports it).
        if forced_lang and hasattr(processor.tokenizer, "set_prefix_tokens"):
            processor.tokenizer.set_prefix_tokens(language=forced_lang)

        generated_ids = model.generate(**inputs, **gen_kwargs)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return {"transcription": text}

    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}")
    finally:
        # Best-effort cleanup of both temp files.
        for p in (tmp_in, tmp_wav):
            if p and os.path.exists(p):
                try:
                    os.remove(p)
                except Exception:
                    pass
 
 
 
 
1
  from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from transformers import pipeline
3
  import numpy as np
4
  import tempfile
5
  import os
6
  import subprocess
7
  import soundfile as sf
8
 
 
 
 
9
  app = FastAPI()
10
 
11
# Load the ASR pipeline on startup; on failure leave a None sentinel so
# the endpoint can answer 503 instead of crashing the whole app.
try:
    asr_pipeline = pipeline(
        "automatic-speech-recognition",
        model="distil-whisper/distil-large-v3",
        torch_dtype=None,  # let pipeline pick sensible dtype
        device="cpu",
    )
    print("✅ ASR model loaded successfully")
except Exception as e:
    asr_pipeline = None
    print(f"❌ Error loading ASR model: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
def _to_16k_mono_wav(src_path: str) -> str:
    """Transcode *src_path* with ffmpeg into a fresh 16 kHz mono WAV temp file.

    Returns the path of the new file (caller owns and must delete it).
    Raises RuntimeError carrying ffmpeg's stderr when conversion fails.
    """
    # Only the file name is needed; close the handle right away so ffmpeg
    # can overwrite the path on every platform.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tfwav:
        wav_path = tfwav.name

    ffmpeg_cmd = [
        "ffmpeg",
        "-y",              # overwrite
        "-i", src_path,    # input file
        "-ar", "16000",    # sample rate 16k
        "-ac", "1",        # mono
        "-f", "wav",
        wav_path,
    ]
    proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        # include ffmpeg stderr for debugging
        raise RuntimeError(f"ffmpeg error: {proc.stderr.strip()}")
    return wav_path


@app.post("/transcribe")
async def transcribe_audio(audio_file: UploadFile = File(...)):
    """Transcribe an uploaded audio file.

    Flow: save upload -> ffmpeg to 16 kHz mono WAV (stable, avoids
    librosa/numba) -> soundfile read -> transformers ASR pipeline.

    Returns {"transcription": <text>}.
    Raises HTTP 503 when the model failed to load at startup, and
    HTTP 400 for any per-request processing failure.
    """
    if not asr_pipeline:
        raise HTTPException(status_code=503, detail="ASR model is not available.")

    audio_bytes = await audio_file.read()

    tmp_in = None
    tmp_wav = None
    try:
        # 1) Save uploaded bytes to a temporary file (preserve extension if available)
        suffix = os.path.splitext(audio_file.filename)[1] or ""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tf:
            tf.write(audio_bytes)
            tf.flush()
            tmp_in = tf.name

        # 2) Normalize to 16 kHz mono WAV PCM (extracted helper above)
        tmp_wav = _to_16k_mono_wav(tmp_in)

        # 3) Read WAV with soundfile into float32 waveform
        speech, sr = sf.read(tmp_wav, dtype="float32")
        if sr != 16000:
            # should not happen because ffmpeg forced 16k, but check anyway
            raise RuntimeError(f"Unexpected sample rate {sr}")
        if speech.ndim > 1:
            # ensure mono
            speech = np.mean(speech, axis=1)

        # 4) Transcribe; chunking options keep memory bounded on long audio
        result = asr_pipeline(speech, chunk_length_s=30, stride_length_s=5)

        # The pipeline may return a dict or a list of dicts; flattened from
        # the original nested ternary for readability.
        if isinstance(result, dict):
            text = result.get("text", "")
        elif isinstance(result, list) and result:
            text = result[0].get("text", "")
        else:
            text = ""

        return {"transcription": text}

    except Exception as e:
        # Return a 400 with a helpful message
        raise HTTPException(status_code=400, detail=f"Could not process audio file: {e}")
    finally:
        # cleanup temp files
        for path in (tmp_in, tmp_wav):
            if path and os.path.exists(path):
                try:
                    os.remove(path)
                except Exception:
                    pass