KuyaToto committed on
Commit
1cf7af2
·
verified ·
1 Parent(s): a644227

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -20
app.py CHANGED
@@ -2,43 +2,41 @@ import gradio as gr
2
  import torch
3
  import numpy as np
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
5
 
6
- # Load model
7
  model_id = "facebook/wav2vec2-large-960h-lv60-self"
8
  processor = Wav2Vec2Processor.from_pretrained(model_id)
9
  model = Wav2Vec2ForCTC.from_pretrained(model_id)
10
 
11
- # Transcription function
12
- def transcribe(audio_np, sample_rate):
13
- if audio_np is None:
14
- return "No audio received."
15
 
16
- # Resample to 16kHz if needed
17
  if sample_rate != 16000:
18
- import torchaudio
19
- audio_tensor = torch.tensor(audio_np).unsqueeze(0)
20
- resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
21
- audio_tensor = resampler(audio_tensor)
22
- else:
23
- audio_tensor = torch.tensor(audio_np).unsqueeze(0)
24
 
25
- input_values = processor(audio_tensor.squeeze(), sampling_rate=16000, return_tensors="pt").input_values
 
26
 
 
27
  with torch.no_grad():
28
  logits = model(input_values).logits
 
 
29
 
30
- predicted_ids = torch.argmax(logits, dim=-1)
31
- transcription = processor.batch_decode(predicted_ids)[0]
32
  return transcription.lower()
33
 
34
- # Interface
35
  demo = gr.Interface(
36
  fn=transcribe,
37
- inputs=gr.Audio(source="microphone", type="numpy", label="Speak now"),
38
- outputs=gr.Textbox(label="Transcription"),
39
- live=False,
40
  title="Wav2Vec2 Speech Transcription",
41
- description="Speak into your mic and get a transcription using Wav2Vec2!"
42
  )
43
 
44
  demo.launch()
 
2
  import torch
3
  import numpy as np
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
+ from scipy.signal import resample
6
 
7
+ # Load model and processor
8
  model_id = "facebook/wav2vec2-large-960h-lv60-self"
9
  processor = Wav2Vec2Processor.from_pretrained(model_id)
10
  model = Wav2Vec2ForCTC.from_pretrained(model_id)
11
 
12
+ # Transcribe function
13
def transcribe(audio, sample_rate=None):
    """Transcribe a speech recording to lowercase text with Wav2Vec2.

    Args:
        audio: Either an array-like of audio samples, or the
            ``(sample_rate, samples)`` tuple that a Gradio audio input
            configured with ``type="numpy"`` hands over as ONE argument.
            ``None`` means nothing was recorded.
        sample_rate: Sampling rate of ``audio`` in Hz. May be omitted when
            ``audio`` is a ``(sample_rate, samples)`` tuple.

    Returns:
        The greedy (argmax) CTC decoding of the recording, lowercased, or a
        warning string when no audio was supplied.

    Raises:
        ValueError: If the sampling rate cannot be determined.
    """
    if audio is None:
        return "⚠️ No audio received."

    # Gradio calls fn with one value per input component, so the microphone
    # capture arrives as a single (sample_rate, samples) tuple.
    if sample_rate is None and isinstance(audio, tuple) and len(audio) == 2:
        sample_rate, audio = audio
    if sample_rate is None:
        raise ValueError(
            "sample_rate is required unless audio is a (sample_rate, samples) tuple"
        )

    samples = np.asarray(audio)
    if samples.size == 0:
        return "⚠️ No audio received."

    # The model expects floating-point waveforms; integer PCM is scaled into
    # [-1, 1] using the full range of its dtype.
    if np.issubdtype(samples.dtype, np.integer):
        samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
    else:
        samples = samples.astype(np.float32)

    # Down-mix (samples, channels) recordings to mono.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    # Resample audio to 16 kHz if needed (the rate passed to the processor).
    if sample_rate != 16000:
        number_of_samples = round(len(samples) * 16000 / sample_rate)
        samples = resample(samples, number_of_samples)

    # Prepare input
    input_values = processor(
        samples, sampling_rate=16000, return_tensors="pt"
    ).input_values

    # Run model without tracking gradients (inference only).
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    return transcription.lower()
32
 
33
+ # Gradio interface
34
# Gradio interface: a microphone recording goes in, the transcript text
# comes out. The function, input and output are passed positionally.
demo = gr.Interface(
    transcribe,
    gr.Audio(source="microphone", type="numpy", label="🎤 Speak now"),
    gr.Textbox(label="📝 Transcription"),
    title="Wav2Vec2 Speech Transcription",
    description="Speak and get real-time transcription using Wav2Vec2 (Hugging Face).",
)

# Start the web app.
demo.launch()