KuyaToto committed on
Commit
1c1fe8b
·
verified ·
1 Parent(s): 7658fb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -25
app.py CHANGED
@@ -1,8 +1,7 @@
1
  import gradio as gr
2
  import torch
3
- import numpy as np
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
- from scipy.signal import resample
6
 
7
  # Load model and processor
8
  model_id = "facebook/wav2vec2-large-960h-lv60-self"
@@ -10,43 +9,37 @@ processor = Wav2Vec2Processor.from_pretrained(model_id)
10
  model = Wav2Vec2ForCTC.from_pretrained(model_id)
11
 
12
  # Transcription function
13
- def transcribe(audio_data):
14
- if audio_data is None:
15
  return "⚠️ No audio received."
16
 
17
- audio, sample_rate = audio_data
18
-
19
- # Convert stereo to mono if needed
20
- if len(audio.shape) == 2:
21
- audio = np.mean(audio, axis=1)
22
-
23
- # Ensure sample_rate is an integer
24
- sample_rate = int(sample_rate)
25
-
26
- # Resample to 16000 Hz if needed
27
  if sample_rate != 16000:
28
- number_of_samples = round(len(audio) * 16000 / sample_rate)
29
- audio = resample(audio, number_of_samples)
30
 
31
- # Normalize audio
32
- audio = audio.astype(np.float32)
 
 
 
33
 
34
- # Process and predict
35
- input_values = processor(audio, sampling_rate=16000, return_tensors="pt").input_values
36
  with torch.no_grad():
37
  logits = model(input_values).logits
38
- predicted_ids = torch.argmax(logits, dim=-1)
39
- transcription = processor.batch_decode(predicted_ids)[0]
 
40
 
41
  return transcription.lower()
42
 
43
- # Launch UI
44
  demo = gr.Interface(
45
  fn=transcribe,
46
- inputs=gr.Audio(sources=["microphone"], type="numpy", label="🎤 Speak now"),
47
  outputs=gr.Textbox(label="📝 Transcription"),
48
  title="Wav2Vec2 Speech Transcription",
49
- description="Speak into the microphone and get a transcription using Wav2Vec2 (Hugging Face)."
50
  )
51
 
52
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import torchaudio
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
5
 
6
# Load the pretrained Wav2Vec2 ASR model and its processor once at import
# time so every transcription request reuses them.
# NOTE(review): the processor line below is collapsed out of this diff view
# but is present in the file (see the hunk header context) — restored here
# so the setup reads as a complete unit.
model_id = "facebook/wav2vec2-large-960h-lv60-self"
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)

# Transcription function
12
def transcribe(audio_file):
    """Transcribe a recorded audio file to lowercase English text.

    Args:
        audio_file: Filesystem path to the audio recorded by the Gradio
            microphone component (``type="filepath"``), or ``None`` when
            no audio was captured.

    Returns:
        The lowercase transcription string, or a warning message when no
        audio was received.
    """
    if audio_file is None:
        return "⚠️ No audio received."

    # Load audio; torchaudio returns (channels, samples) plus the sample rate.
    waveform, sample_rate = torchaudio.load(audio_file)

    # Downmix to mono FIRST so any resampling below processes one channel
    # instead of two — resampling is linear, so mean-then-resample gives the
    # same signal as resample-then-mean at half the cost.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Wav2Vec2 was trained on 16 kHz audio; resample anything else.
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=16000
        )
        sample_rate = 16000

    # Feature-extract into model-ready tensors (1-D numpy input expected).
    input_values = processor(
        waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt"
    ).input_values

    # Inference only — no gradient bookkeeping.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token per frame, then collapse.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    return transcription.lower()
35
 
36
# Gradio UI: microphone in, transcribed text out.
mic_input = gr.Audio(sources=["microphone"], type="filepath", label="🎤 Speak now")
text_output = gr.Textbox(label="📝 Transcription")

demo = gr.Interface(
    fn=transcribe,
    inputs=mic_input,
    outputs=text_output,
    title="Wav2Vec2 Speech Transcription",
    description="Speak into the microphone and get a transcription using Wav2Vec2 (via Hugging Face Transformers).",
)

demo.launch()