KuyaToto committed on
Commit
9c52375
·
verified ·
1 Parent(s): 3033349

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -16
app.py CHANGED
@@ -3,52 +3,40 @@ import torch
3
  import torchaudio
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
6
- # Load the base model and processor outside the function to avoid reloading
7
  model_id = "facebook/wav2vec2-base-960h"
8
  processor = Wav2Vec2Processor.from_pretrained(model_id)
9
  model = Wav2Vec2ForCTC.from_pretrained(model_id)
10
 
11
- # Transcription function with optimization
12
  def transcribe(audio_file, progress=gr.Progress()):
13
  if audio_file is None:
14
  return "⚠️ No audio received."
15
 
16
- print(f"📥 Received file path: {audio_file}") # ✅ Log for debugging
17
-
18
- progress(0, desc="Loading audio...")
19
  waveform, sample_rate = torchaudio.load(audio_file)
20
 
21
  if sample_rate != 16000:
22
- progress(0.3, desc="Resampling audio...")
23
  waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=16000)
24
  sample_rate = 16000
25
 
26
  if waveform.shape[0] > 1:
27
- progress(0.5, desc="Converting to mono...")
28
  waveform = waveform.mean(dim=0).unsqueeze(0)
29
 
30
- progress(0.7, desc="Processing audio...")
31
  input_values = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt").input_values
32
 
33
  with torch.no_grad():
34
- progress(0.8, desc="Transcribing...")
35
  logits = model(input_values).logits
36
 
37
  predicted_ids = torch.argmax(logits, dim=-1)
38
  transcription = processor.batch_decode(predicted_ids)[0]
39
-
40
- progress(1.0, desc="Done!")
41
  return transcription.lower()
42
 
43
- # Gradio UI with POST API support
44
  demo = gr.Interface(
45
  fn=transcribe,
46
  inputs=gr.Audio(sources=["microphone"], type="filepath", label="🎤 Speak now"),
47
  outputs=gr.Textbox(label="📝 Transcription"),
48
  title="Wav2Vec2 Speech Transcription",
49
- description="Speak into the microphone and get a transcription using Wav2Vec2-base (via Hugging Face Transformers).",
50
- allow_flagging="never"
51
  )
52
 
53
- # ✅ Enable API mode to allow POST requests
54
- demo.launch(api=True)
 
3
  import torchaudio
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
# Load the model
# Loaded once at module import (module scope) so the checkpoint is not
# re-fetched on every transcription request.
model_id = "facebook/wav2vec2-base-960h"
# Processor: feature extraction (raw waveform -> input_values) on the way in,
# CTC token-id decoding back to text on the way out.
processor = Wav2Vec2Processor.from_pretrained(model_id)
model = Wav2Vec2ForCTC.from_pretrained(model_id)
10
 
 
def transcribe(audio_file, progress=gr.Progress()):
    """Transcribe a recorded audio file to lowercase text with Wav2Vec2.

    Args:
        audio_file: Filesystem path to the recorded clip (Gradio
            ``type="filepath"``), or ``None`` when nothing was captured.
        progress: Gradio progress tracker; used to surface pipeline stages
            in the UI. (Previously declared but unused — now wired up.)

    Returns:
        The lowercased transcription string, or a warning message when no
        audio was received.
    """
    if audio_file is None:
        return "⚠️ No audio received."

    progress(0.1, desc="Loading audio...")
    waveform, sample_rate = torchaudio.load(audio_file)

    # Wav2Vec2-base-960h expects 16 kHz input; resample anything else.
    if sample_rate != 16000:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=16000
        )
        sample_rate = 16000

    # Collapse multi-channel recordings to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0).unsqueeze(0)

    progress(0.5, desc="Transcribing...")
    input_values = processor(
        waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt"
    ).input_values

    # Inference only: disable autograd to save memory and time.
    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: most likely token id at each frame, then decode.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]

    progress(1.0, desc="Done!")
    return transcription.lower()
32
 
 
# Gradio UI: microphone recording in -> transcribe() -> text box out.
demo = gr.Interface(
    fn=transcribe,
    # type="filepath" hands transcribe() a path on disk, not raw samples.
    inputs=gr.Audio(sources=["microphone"], type="filepath", label="🎤 Speak now"),
    outputs=gr.Textbox(label="📝 Transcription"),
    title="Wav2Vec2 Speech Transcription",
    description="Speak into the microphone and get a transcription using Wav2Vec2-base.",
    # Disable the built-in flagging button entirely.
    flagging_mode="never"
)

# Start the app with default server settings.
demo.launch()