CrazyMonkey0 committed on
Commit
9ea2744
·
1 Parent(s): 3ad9eac

fix(asr): load audio from in-memory buffer instead of disk

Browse files

Replaced file-based audio loading with io.BytesIO to handle uploaded audio directly in memory.
SoundFile reads the buffer, librosa resamples the audio to 16 kHz if needed, and the result is fed to the Whisper ASR model.

Files changed (1) hide show
  1. app/routes/asr.py +8 -4
app/routes/asr.py CHANGED
@@ -1,6 +1,8 @@
1
  from transformers import WhisperForConditionalGeneration, WhisperProcessor
2
  from fastapi import APIRouter, Request, UploadFile, File
3
  import librosa
 
 
4
  import os
5
 
6
  router = APIRouter()
@@ -15,12 +17,14 @@ async def asr(request: Request, audio: UploadFile = File(...)):
15
  # Get the loaded ASR model and processor
16
  processor, model = request.app.state.processor_asr, request.app.state.model_asr
17
  # Audio file path
18
- audio_path = os.path.join(request.app.state.AUDIO_DIR, "temp", audio.filename)
19
- with open(audio_path, "wb") as f:
20
- f.write(await audio.read())
21
 
22
  # Loading audio file
23
- audio_data, sampling_rate = librosa.load(audio_path, sr=16000)
 
 
 
24
 
25
  # Preparing input data
26
  inputs = processor(audio_data, return_tensors="pt", sampling_rate=sampling_rate)
 
1
  from transformers import WhisperForConditionalGeneration, WhisperProcessor
2
  from fastapi import APIRouter, Request, UploadFile, File
3
  import librosa
4
+ import io
5
+ import soundfile as sf
6
  import os
7
 
8
  router = APIRouter()
 
17
  # Get the loaded ASR model and processor
18
  processor, model = request.app.state.processor_asr, request.app.state.model_asr
19
  # Audio file path
20
+ audio_bytes = await audio.read()
21
+ buffer = io.BytesIO(audio_bytes)
 
22
 
23
  # Loading audio file
24
+ audio_data, sampling_rate = sf.read(buffer, dtype="float32")
25
+ if sampling_rate != 16000:
26
+ audio_data = librosa.resample(audio_data, orig_sr=sampling_rate, target_sr=16000)
27
+ sampling_rate = 16000
28
 
29
  # Preparing input data
30
  inputs = processor(audio_data, return_tensors="pt", sampling_rate=sampling_rate)