nambn0321 committed
Commit 0663839 · verified · 1 Parent(s): 53b8002

Update app.py

Files changed (1)
  1. app.py +4 -9
app.py CHANGED
@@ -3,29 +3,26 @@ import torchaudio
 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 import gradio as gr
 
-# Load model and processor from your fine-tuned directory
-MODEL_PATH = r"nambn0321/ASR_models"
+MODEL_PATH = "nambn0321/ASR_models"  # Your HF model repo
 processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
 model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH).eval()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
 
-# Define inference function
 def transcribe(audio):
     if audio is None:
         return "No audio provided."
 
     sr, data = audio
-
-    # Convert to mono and resample to 16kHz if needed
     waveform = torch.tensor(data).unsqueeze(0)
+
     if sr != 16000:
         resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
         waveform = resampler(waveform)
+
     if waveform.shape[0] > 1:
         waveform = waveform.mean(dim=0, keepdim=True)
 
-    # Inference
     inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
     input_values = inputs.input_values.to(device)
 
@@ -36,12 +33,10 @@ def transcribe(audio):
     transcription = processor.batch_decode(predicted_ids)[0]
     return transcription.strip()
 
-# Gradio interface
 gr.Interface(
     fn=transcribe,
-    inputs=gr.Audio(source="upload", type="numpy", label="Upload WAV/MP3 file"),
+    inputs=gr.Audio(type="numpy", label="Upload WAV/MP3 file"),
     outputs=gr.Textbox(label="Transcription"),
     title="🗣️ ASR Demo with Wav2Vec2",
     description="Upload an audio file (WAV or MP3) and get the transcription using your fine-tuned model.",
-    live=False
 ).launch()
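
The dropped source="upload" keyword matters because Gradio 4.x removed that argument from gr.Audio (it now takes a sources list, with upload enabled by default), so the older call would fail on a current Space runtime. For a quick sanity check of the updated inference path outside Gradio, a minimal offline sketch follows; "sample.wav" is a placeholder file name, not part of the commit, and the model repo name is taken from the diff.

# Offline sketch of the same inference path as app.py ("sample.wav" is a placeholder).
import torch
import torchaudio
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

MODEL_PATH = "nambn0321/ASR_models"
processor = Wav2Vec2Processor.from_pretrained(MODEL_PATH)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_PATH).eval()

waveform, sr = torchaudio.load("sample.wav")            # shape: (channels, samples)
if sr != 16000:                                          # resample to the model's 16 kHz rate
    waveform = torchaudio.transforms.Resample(sr, 16000)(waveform)
if waveform.shape[0] > 1:                                # mix stereo down to mono
    waveform = waveform.mean(dim=0, keepdim=True)

inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs.input_values).logits
predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids)[0].strip())

This mirrors the preprocessing inside transcribe (resample to 16 kHz, mix down to mono) but loads the file with torchaudio instead of receiving the (sample_rate, numpy_array) tuple from gr.Audio(type="numpy").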