KuyaToto committed on
Commit
dcc8cf4
·
verified ·
1 Parent(s): fd059a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -8
app.py CHANGED
@@ -3,43 +3,54 @@ import torch
3
  import torchaudio
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
6
- # Load model and processor
7
  model_id = "facebook/wav2vec2-large-960h-lv60-self"
8
  processor = Wav2Vec2Processor.from_pretrained(model_id)
9
  model = Wav2Vec2ForCTC.from_pretrained(model_id)
10
 
11
- # Transcription function
12
- def transcribe(audio_file):
13
  if audio_file is None:
14
  return "⚠️ No audio received."
15
 
16
- # Load and convert audio
 
17
  waveform, sample_rate = torchaudio.load(audio_file)
 
 
18
  if sample_rate != 16000:
 
 
19
  waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)
20
  sample_rate = 16000
21
 
22
- # Only one channel (mono)
23
  if waveform.shape[0] > 1:
 
24
  waveform = waveform.mean(dim=0).unsqueeze(0)
25
 
 
 
26
  input_values = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt").input_values
27
 
28
  with torch.no_grad():
 
29
  logits = model(input_values).logits
30
 
31
  predicted_ids = torch.argmax(logits, dim=-1)
32
  transcription = processor.batch_decode(predicted_ids)[0]
33
 
 
34
  return transcription.lower()
35
 
36
- # Gradio UI
37
  demo = gr.Interface(
38
  fn=transcribe,
39
  inputs=gr.Audio(sources=["microphone"], type="filepath", label="🎤 Speak now"),
40
  outputs=gr.Textbox(label="📝 Transcription"),
41
  title="Wav2Vec2 Speech Transcription",
42
- description="Speak into the microphone and get a transcription using Wav2Vec2 (via Hugging Face Transformers)."
 
43
  )
44
 
45
- demo.launch()
 
3
  import torchaudio
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
6
# Load the processor and model a single time at import; every transcription
# request served by the UI then reuses these objects instead of re-downloading
# and re-initialising the checkpoint.
model_id = "facebook/wav2vec2-large-960h-lv60-self"
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)
10
 
11
# Transcription function with optimization
def transcribe(audio_file, progress=gr.Progress()):
    """Transcribe a recorded audio file to lowercase text.

    Args:
        audio_file: Filesystem path to the recording (the Audio input uses
            type="filepath"), or None when nothing was captured.
        progress: Gradio progress tracker used to surface status in the UI.

    Returns:
        The lowercase transcription, or a warning string when no audio was
        received or the file could not be decoded.
    """
    if audio_file is None:
        return "⚠️ No audio received."

    progress(0, desc="Loading audio...")
    try:
        waveform, sample_rate = torchaudio.load(audio_file)
    except Exception as err:
        # A truncated/unsupported recording should show a message in the
        # textbox rather than crash the request.
        return f"⚠️ Could not read audio: {err}"

    # Downmix to mono BEFORE resampling so the resampler only has to
    # process one channel (halves the work for stereo recordings).
    if waveform.shape[0] > 1:
        progress(0.3, desc="Converting to mono...")
        waveform = waveform.mean(dim=0, keepdim=True)

    # Wav2Vec2 expects 16 kHz input; resample only when necessary.
    if sample_rate != 16000:
        progress(0.5, desc="Resampling audio...")
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=16000
        )
        sample_rate = 16000

    progress(0.7, desc="Processing audio...")
    input_values = processor(
        waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt"
    ).input_values

    # Inference only — no gradients needed.
    with torch.no_grad():
        progress(0.8, desc="Transcribing...")
        logits = model(input_values).logits

    # Greedy CTC decoding: pick the highest-probability token per frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    progress(1.0, desc="Done!")
    return transcription.lower()
45
 
46
# Gradio UI with progress tracking: microphone in, transcription text out.
_mic_input = gr.Audio(sources=["microphone"], type="filepath", label="🎤 Speak now")
_text_output = gr.Textbox(label="📝 Transcription")

demo = gr.Interface(
    fn=transcribe,
    inputs=_mic_input,
    outputs=_text_output,
    title="Wav2Vec2 Speech Transcription",
    description="Speak into the microphone and get a transcription using Wav2Vec2 (via Hugging Face Transformers).",
    allow_flagging="never",  # Optional: Reduces overhead
)

demo.launch()