omar1232 committed on
Commit
fc78b67
·
verified ·
1 Parent(s): 7993fc2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -21
app.py CHANGED
@@ -1,44 +1,51 @@
1
  import gradio as gr
2
- import speech_recognition as sr
3
  from pydub import AudioSegment
4
  import tempfile
5
  from langdetect import detect
6
  import os
7
  import asyncio
 
 
8
  from telegram import Update
9
  from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
10
 
11
  # Telegram bot token (to be set via Hugging Face Space secrets)
12
  TELEGRAM_BOT_TOKEN = os.getenv("TELEGRAM_BOT_TOKEN")
13
 
 
 
 
 
14
  # Process audio and transcribe
15
  def process_audio(audio_input):
16
- recognizer = sr.Recognizer()
17
-
18
- # Convert all audio inputs to WAV format
19
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
20
  if isinstance(audio_input, tuple): # Recorded audio (sample_rate, numpy_array)
21
  sample_rate, audio_data = audio_input
22
- AudioSegment(audio_data, sample_rate=sample_rate, frame_rate=sample_rate, channels=1).export(temp_file.name, format="wav")
23
  else: # Uploaded audio file (file path or Telegram file)
24
- audio = AudioSegment.from_file(audio_input)
25
- audio = audio.set_channels(1) # Convert to mono for consistency
26
- audio.export(temp_file.name, format="wav")
 
 
27
  audio_file_path = temp_file.name
28
 
29
- # Debug: Check if the WAV file is valid
30
- if os.path.getsize(audio_file_path) == 0:
31
- raise ValueError("The converted WAV file is empty. The input audio may be corrupted.")
32
-
33
- # Transcribe the WAV file using pocketsphinx (offline)
34
- with sr.AudioFile(audio_file_path) as source:
35
- audio = recognizer.record(source)
36
- try:
37
- transcription = recognizer.recognize_sphinx(audio) # Use pocketsphinx for offline transcription
38
- except sr.UnknownValueError:
39
- transcription = "Could not understand the audio."
40
- except sr.RequestError as e:
41
- transcription = f"Transcription failed: {str(e)}"
 
 
42
 
43
  # Detect language
44
  try:
 
1
  import gradio as gr
 
2
  from pydub import AudioSegment
3
  import tempfile
4
  from langdetect import detect
5
  import os
6
  import asyncio
7
+ import torch
8
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
9
  from telegram import Update
10
  from telegram.ext import Application, CommandHandler, MessageHandler, filters, ContextTypes
11
 
12
  # Telegram bot token (to be set via Hugging Face Space secrets)
13
  TELEGRAM_BOT_TOKEN = os.getenv("TELEGRAM_BOT_TOKEN")
14
 
15
+ # Load the Hugging Face model and processor for transcription
16
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
17
+ model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
18
+
19
  # Process audio and transcribe
20
  def process_audio(audio_input):
21
+ # Convert all audio inputs to WAV format with 16kHz sample rate (required by wav2vec2)
 
 
22
  with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
23
  if isinstance(audio_input, tuple): # Recorded audio (sample_rate, numpy_array)
24
  sample_rate, audio_data = audio_input
25
+ audio_segment = AudioSegment(audio_data, sample_rate=sample_rate, frame_rate=sample_rate, channels=1)
26
  else: # Uploaded audio file (file path or Telegram file)
27
+ audio_segment = AudioSegment.from_file(audio_input).set_channels(1)
28
+
29
+ # Resample to 16kHz (required by wav2vec2)
30
+ audio_segment = audio_segment.set_frame_rate(16000)
31
+ audio_segment.export(temp_file.name, format="wav")
32
  audio_file_path = temp_file.name
33
 
34
+ # Load the WAV file for transcription
35
+ import soundfile as sf
36
+ audio_data, sample_rate = sf.read(audio_file_path)
37
+ assert sample_rate == 16000, "Sample rate must be 16kHz"
38
+
39
+ # Preprocess audio for the model
40
+ inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt", padding=True)
41
+
42
+ # Perform transcription
43
+ with torch.no_grad():
44
+ logits = model(inputs.input_values).logits
45
+
46
+ # Decode the logits to text
47
+ predicted_ids = torch.argmax(logits, dim=-1)
48
+ transcription = processor.batch_decode(predicted_ids)[0]
49
 
50
  # Detect language
51
  try: