Somalitts committed on
Commit
df98aad
·
verified ·
1 Parent(s): 568f26c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -20
app.py CHANGED
@@ -1,32 +1,32 @@
1
- import torch
2
  import torchaudio
 
3
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
4
- import gradio as gr
5
 
6
- model = Wav2Vec2ForCTC.from_pretrained("tacab/tacab_asr_somali")
7
- processor = Wav2Vec2Processor.from_pretrained("tacab/tacab_asr_somali")
8
-
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- model.to(device)
11
 
12
  def transcribe(audio):
13
  waveform, sample_rate = torchaudio.load(audio)
 
14
  if sample_rate != 16000:
15
- waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)
16
- if waveform.shape[0] > 1:
17
- waveform = waveform.mean(dim=0, keepdim=True)
18
- inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
19
- input_values = inputs.input_values.to(device)
20
  with torch.no_grad():
21
- logits = model(input_values).logits
 
22
  predicted_ids = torch.argmax(logits, dim=-1)
23
- transcription = processor.batch_decode(predicted_ids)[0]
24
- return transcription.lower()
25
 
26
- gr.Interface(
 
27
  fn=transcribe,
28
- inputs=gr.Audio(type="filepath", label="🎙️ Ku hadal Af Soomaali"),
29
- outputs=gr.Text(label="📄 Qoraalka la helay"),
30
- title="Tacab ASR Somali",
31
- description="ASR model for Somali speech-to-text using Wav2Vec2.",
32
  ).launch()
 
1
+ import gradio as gr
2
  import torchaudio
3
+ import torch
4
  from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
5
 
6
+ # Load model and processor
7
+ processor = Wav2Vec2Processor.from_pretrained("Mustafaa4a/ASR-Somali")
8
+ model = Wav2Vec2ForCTC.from_pretrained("Mustafaa4a/ASR-Somali")
 
 
9
 
10
  def transcribe(audio):
11
  waveform, sample_rate = torchaudio.load(audio)
12
+
13
  if sample_rate != 16000:
14
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
15
+ waveform = resampler(waveform)
16
+
17
+ inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt")
 
18
  with torch.no_grad():
19
+ logits = model(**inputs).logits
20
+
21
  predicted_ids = torch.argmax(logits, dim=-1)
22
+ transcription = processor.decode(predicted_ids[0])
23
+ return transcription
24
 
25
+ # Gradio Interface setup
26
+ interface = gr.Interface(
27
  fn=transcribe,
28
+ inputs=gr.Audio(type="filepath", label="Upload Somali Audio (.wav)"),
29
+ outputs=gr.Textbox(label="Transcription"),
30
+ title="Somali-speech_to_text",
31
+ description="Upload a Somali speech audio file (mono WAV, 16kHz) and get the text transcription."
32
  ).launch()