harshananddev commited on
Commit
b869493
·
verified ·
1 Parent(s): f3d5b5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -26
app.py CHANGED
@@ -1,40 +1,52 @@
1
- !pip install transformers
2
-
3
-
4
  import gradio as gr
5
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
6
  import torch
7
  import torchaudio
8
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Load pre-trained model and tokenizer
10
  model_name = "facebook/wav2vec2-base-960h"
11
  tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
12
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
13
 
14
  def speech_to_text(audio):
15
- # Load audio file
16
- waveform, rate = torchaudio.load(audio.name)
17
-
18
- # Ensure the audio is mono
19
- if waveform.shape[0] > 1:
20
- waveform = torch.mean(waveform, dim=0, keepdim=True)
21
-
22
- # Resample to 16000 Hz
23
- resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
24
- waveform = resampler(waveform)
25
-
26
- # Tokenize the waveform
27
- inputs = tokenizer(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
28
-
29
- # Perform inference
30
- with torch.no_grad():
31
- logits = model(**inputs).logits
32
-
33
- # Decode the output
34
- predicted_ids = torch.argmax(logits, dim=-1)
35
- transcription = tokenizer.batch_decode(predicted_ids)[0]
36
-
37
- return transcription
 
 
 
38
 
39
  # Create Gradio interface
40
  iface = gr.Interface(
 
 
 
 
1
  import gradio as gr
2
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
3
  import torch
4
  import torchaudio
5
 
6
+ # Install the necessary packages
7
+ import subprocess
8
+ import sys
9
+
10
+ def install(package):
11
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
12
+
13
+ install("transformers")
14
+ install("torch")
15
+ install("torchaudio")
16
+ install("gradio")
17
+
18
  # Load pre-trained model and tokenizer
19
  model_name = "facebook/wav2vec2-base-960h"
20
  tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
21
  model = Wav2Vec2ForCTC.from_pretrained(model_name)
22
 
23
  def speech_to_text(audio):
24
+ try:
25
+ # Load audio file
26
+ waveform, rate = torchaudio.load(audio.name)
27
+
28
+ # Ensure the audio is mono
29
+ if waveform.shape[0] > 1:
30
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
31
+
32
+ # Resample to 16000 Hz
33
+ resampler = torchaudio.transforms.Resample(orig_freq=rate, new_freq=16000)
34
+ waveform = resampler(waveform)
35
+
36
+ # Tokenize the waveform
37
+ inputs = tokenizer(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000)
38
+
39
+ # Perform inference
40
+ with torch.no_grad():
41
+ logits = model(**inputs).logits
42
+
43
+ # Decode the output
44
+ predicted_ids = torch.argmax(logits, dim=-1)
45
+ transcription = tokenizer.batch_decode(predicted_ids)[0]
46
+
47
+ return transcription
48
+ except Exception as e:
49
+ return str(e)
50
 
51
  # Create Gradio interface
52
  iface = gr.Interface(