01Yassine commited on
Commit
87e47e2
·
verified ·
1 Parent(s): e1b40db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -9
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  import torchaudio
3
  from transformers import pipeline
 
 
4
 
5
  # Load only the Moul-Sout-100 model
6
  asr_pipeline = pipeline("automatic-speech-recognition", model="01Yassine/moulsot_v0.2_1000")
@@ -10,31 +12,76 @@ asr_pipeline.model.generation_config.input_ids = asr_pipeline.model.generation_c
10
  asr_pipeline.model.generation_config.forced_decoder_ids = None
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def ensure_mono_16k(audio_path):
14
- """Load audio, convert to mono + 16kHz, and save a temp version."""
15
- waveform, sr = torchaudio.load(audio_path)
16
-
17
- # Convert to mono if necessary
18
  if waveform.shape[0] > 1:
19
  waveform = waveform.mean(dim=0, keepdim=True)
20
-
21
- # Resample to 16kHz if necessary
22
  if sr != 16000:
23
  resampler = torchaudio.transforms.Resample(sr, 16000)
24
  waveform = resampler(waveform)
25
  sr = 16000
26
-
27
- tmp_path = "/tmp/processed_16k.wav"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  torchaudio.save(tmp_path, waveform, sr)
29
  return tmp_path
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def transcribe(audio):
33
  if audio is None:
34
  return "Please record or upload an audio file."
35
 
36
  # Process and transcribe
37
- processed_audio = ensure_mono_16k(audio)
38
  result = asr_pipeline(processed_audio)["text"]
39
 
40
  return result
 
1
  import gradio as gr
2
  import torchaudio
3
  from transformers import pipeline
4
+ import soundfile as sf
5
+ import torch
6
 
7
  # Load only the Moul-Sout-100 model
8
  asr_pipeline = pipeline("automatic-speech-recognition", model="01Yassine/moulsot_v0.2_1000")
 
12
  asr_pipeline.model.generation_config.forced_decoder_ids = None
13
 
14
 
15
+ def load_audio(audio_path):
16
+ """Robustly load any audio file into (waveform, sr)"""
17
+ try:
18
+ waveform, sr = torchaudio.load(audio_path)
19
+ except Exception:
20
+ # fallback for unknown backends
21
+ data, sr = sf.read(audio_path)
22
+ waveform = torch.tensor(data, dtype=torch.float32).T
23
+ if waveform.ndim == 1:
24
+ waveform = waveform.unsqueeze(0)
25
+ return waveform, sr
26
+
27
+
28
  def ensure_mono_16k(audio_path):
29
+ """Convert audio to mono + 16 kHz"""
30
+ waveform, sr = load_audio(audio_path)
 
 
31
  if waveform.shape[0] > 1:
32
  waveform = waveform.mean(dim=0, keepdim=True)
 
 
33
  if sr != 16000:
34
  resampler = torchaudio.transforms.Resample(sr, 16000)
35
  waveform = resampler(waveform)
36
  sr = 16000
37
+ return waveform, sr
38
+
39
+
40
+ def trim_leading_silence(waveform, sr, keep_ms=100, threshold=0.01):
41
+ """Trim leading silence, keep ≤ keep_ms ms"""
42
+ energy = waveform.abs().mean(dim=0)
43
+ non_silence_idx = (energy > threshold).nonzero(as_tuple=True)[0]
44
+ if len(non_silence_idx) == 0:
45
+ return waveform # all silence
46
+ first_non_silence = non_silence_idx[0].item()
47
+ keep_samples = int(sr * (keep_ms / 1000.0))
48
+ start = max(0, first_non_silence - keep_samples)
49
+ return waveform[:, start:]
50
+
51
+
52
+ def preprocess_audio(audio_path):
53
+ waveform, sr = ensure_mono_16k(audio_path)
54
+ waveform = trim_leading_silence(waveform, sr, keep_ms=100, threshold=0.01)
55
+ tmp_path = "/tmp/processed_trimmed.wav"
56
  torchaudio.save(tmp_path, waveform, sr)
57
  return tmp_path
58
 
59
 
60
+ # def ensure_mono_16k(audio_path):
61
+ # """Load audio, convert to mono + 16kHz, and save a temp version."""
62
+ # waveform, sr = torchaudio.load(audio_path)
63
+
64
+ # # Convert to mono if necessary
65
+ # if waveform.shape[0] > 1:
66
+ # waveform = waveform.mean(dim=0, keepdim=True)
67
+
68
+ # # Resample to 16kHz if necessary
69
+ # if sr != 16000:
70
+ # resampler = torchaudio.transforms.Resample(sr, 16000)
71
+ # waveform = resampler(waveform)
72
+ # sr = 16000
73
+
74
+ # tmp_path = "/tmp/processed_16k.wav"
75
+ # torchaudio.save(tmp_path, waveform, sr)
76
+ # return tmp_path
77
+
78
+
79
  def transcribe(audio):
80
  if audio is None:
81
  return "Please record or upload an audio file."
82
 
83
  # Process and transcribe
84
+ processed_audio = preprocess_audio(audio)
85
  result = asr_pipeline(processed_audio)["text"]
86
 
87
  return result