# audio2phoneme / app.py
# HuggingFace Space by IvanLayer7 — last commit "Update app.py" (a7fcc90, verified)
import gradio as gr
import torch
import librosa
import numpy as np
import os
from transformers import (
Wav2Vec2Processor,
Wav2Vec2ForCTC,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor
)
# Global model state: populated by load_model() and read by
# get_phoneme_transcription(), which refuses to run while either is None.
processor = None
model = None

# Access code for the in-app login screen, read from the environment.
# On HuggingFace Spaces, define a repository secret named "ACCESS_CODE";
# never write the code itself into source or comments.
# NOTE(review): an earlier revision of this comment published the code in
# plain text — rotate it.
# If the variable is unset this is None, and no typed code can match it,
# so nobody gets past the login page.
ACCESS_CODE = os.environ.get("ACCESS_CODE")
def load_model():
    """
    Load the Wav2Vec2 phoneme checkpoint into the module-level globals.

    Sets ``processor`` (feature extractor + CTC tokenizer) and ``model``
    (Wav2Vec2ForCTC), printing progress before and after.
    """
    global processor, model

    checkpoint = "facebook/wav2vec2-lv-60-espeak-cv-ft"
    print("Loading Wav2Vec2 phoneme model...")

    # Assemble the processor from its two parts instead of calling
    # Wav2Vec2Processor.from_pretrained directly.
    # NOTE(review): presumably this avoids the phonemizer-backed tokenizer
    # the checkpoint would otherwise resolve to — confirm before changing.
    processor = Wav2Vec2Processor(
        feature_extractor=Wav2Vec2FeatureExtractor.from_pretrained(checkpoint),
        tokenizer=Wav2Vec2CTCTokenizer.from_pretrained(checkpoint),
    )
    model = Wav2Vec2ForCTC.from_pretrained(checkpoint)

    print("Model loaded successfully!")
def get_phoneme_transcription(audio_path):
    """
    Transcribe an audio file into phonemes with the globally loaded model.

    The audio is loaded and resampled to 16 kHz with librosa. It is then run
    through the Wav2Vec2 CTC model and greedy-decoded (argmax over each
    frame's logits) into a phoneme string.

    Args:
        audio_path (str): Path to the audio file.

    Returns:
        tuple: (phoneme transcription, audio info), both display strings.
        On failure the first element is an error message and the second is
        an empty string. Exceptions are reported, not propagated.
    """
    if processor is None or model is None:
        return "Error: Model not loaded", ""

    target_sr = 16000  # resampling target, also declared to the processor

    try:
        audio_array, sampling_rate = librosa.load(audio_path, sr=target_sr)

        # Nothing to transcribe: say so plainly instead of surfacing whatever
        # error the model raises on zero-length input.
        num_samples = len(audio_array)
        if num_samples == 0:
            return "❌ Error: the audio contains no samples", ""

        duration = num_samples / sampling_rate
        audio_info = (
            "📊 Audio Info:\n"
            f" • Duration: {duration:.2f} seconds\n"
            f" • Sample Rate: {sampling_rate} Hz\n"
            f" • Samples: {num_samples}\n"
        )

        input_values = processor(
            audio_array, return_tensors="pt", sampling_rate=target_sr
        ).input_values

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model(input_values).logits

        # Greedy decoding: take the most likely token id for each frame,
        # then turn the ids back into text with the processor's tokenizer.
        predicted_ids = torch.argmax(logits, dim=-1)
        phoneme_result = processor.batch_decode(predicted_ids)[0]

        result = (
            f"🗣️ Phoneme Transcription (IPA):\n\n{phoneme_result}\n\n"
            f"📝 Character count: {len(phoneme_result)}"
        )
        return result, audio_info
    except Exception as e:
        # UI boundary: show the failure in the output box, and also log the
        # full traceback server-side so it is not lost.
        import traceback

        traceback.print_exc()
        return f"❌ Error: {e}", ""
def predict_phonemes(audio):
    """
    Gradio callback: turn the audio component's value into a transcription.

    Args:
        audio: Value from the gr.Audio component. This is either a file path,
            a ``(sample_rate, audio_array)`` tuple, or None when nothing was
            provided.

    Returns:
        tuple: (phoneme transcription, audio info) display strings.
    """
    if audio is None:
        return "⚠️ Please upload or record an audio file", ""

    if not isinstance(audio, tuple):
        # Already a file path: hand it straight to the transcriber.
        return get_phoneme_transcription(audio)

    # (sample_rate, audio_array): write it to a temporary WAV so the
    # path-based transcriber can load it.
    import tempfile
    import soundfile as sf

    sample_rate, audio_array = audio

    # Only reserve a unique file name here. The handle is closed before
    # writing, because reopening a still-open NamedTemporaryFile by name is
    # not supported on every platform.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        audio_path = tmp_file.name

    try:
        sf.write(audio_path, audio_array, sample_rate)
        return get_phoneme_transcription(audio_path)
    finally:
        # delete=False means nothing else removes this file. Clean it up so
        # recordings do not accumulate in the temp directory (best-effort).
        try:
            os.remove(audio_path)
        except OSError:
            pass
def check_password(password):
    """
    Decide whether the entered access code unlocks the main app.

    Args:
        password (str): Text typed into the access-code box.

    Returns:
        dict: Gradio updates keyed by component. On a match, ``login_page``
        is hidden and ``main_app`` shown; otherwise the reverse. If
        ACCESS_CODE is unset or empty, every attempt is rejected.
    """
    import hmac

    # Fail closed when no access code is configured, so an empty entry can
    # never match an empty or missing secret. compare_digest takes time
    # independent of where the inputs differ, so response timing does not
    # leak how much of a guess was correct.
    granted = bool(ACCESS_CODE) and hmac.compare_digest(
        (password or "").encode("utf-8"),
        ACCESS_CODE.encode("utf-8"),
    )
    return {
        login_page: gr.update(visible=not granted),
        main_app: gr.update(visible=granted),
    }
# Load the model once at startup; every request then reuses the globals.
load_model()

# Build the Gradio interface. The login column is shown first; the main app
# column starts hidden, and check_password() toggles the two.
with gr.Blocks(title="Phoneme Transcription App", theme=gr.themes.Soft()) as demo:
    # Login page (visible until a correct code is entered)
    with gr.Column(visible=True) as login_page:
        gr.Markdown(
            """
            # 🔐 Access Code Required

            Please enter the access code to use the Phoneme Transcription App.
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                pass  # left spacer to center the form
            with gr.Column(scale=2):
                password_input = gr.Textbox(
                    label="Enter Access Code",
                    placeholder="Type access code here...",
                    type="password",
                    lines=1,
                )
                login_btn = gr.Button("🚀 Access App", variant="primary", size="lg")
            with gr.Column(scale=1):
                pass  # right spacer

    # Main app (hidden until login succeeds)
    with gr.Column(visible=False) as main_app:
        gr.Markdown(
            """
            # Keywords Spotting (KWS)
            ## 🎙️ Phoneme Transcription with Wav2Vec2

            Upload or record audio to get phoneme transcription in IPA (International Phonetic Alphabet) format.

            The first use could be slower than subsequent uses.
            """
        )
        with gr.Row():
            with gr.Column():
                audio_input = gr.Audio(
                    sources=["upload", "microphone"],
                    type="filepath",
                    label="Upload or Record Audio",
                )
                predict_btn = gr.Button(
                    "🔍 Get Phoneme Transcription", variant="primary", size="lg"
                )
                gr.Markdown(
                    """
                    ### 💡 Tips:
                    - Supports WAV, MP3, OGG, and other audio formats
                    - Audio will be automatically resampled to 16kHz
                    - Works best with clear speech (ideally)
                    - Supports multiple languages including Spanish
                    """
                )
            with gr.Column():
                phoneme_output = gr.Textbox(
                    label="Phoneme Transcription",
                    lines=8,
                    placeholder="Phoneme transcription will appear here...",
                )
                audio_info_output = gr.Textbox(
                    label="Audio Information",
                    lines=5,
                    placeholder="Audio details will appear here...",
                )

    # Run transcription when the button is pressed.
    # NOTE(review): the login gate only toggles component visibility in the
    # browser. api_name=False keeps this event out of the API docs and away
    # from API consumers (gradio_client / gr.load) as defense in depth. It is
    # not a substitute for server-side authentication — confirm against the
    # deployed Gradio version.
    predict_btn.click(
        fn=predict_phonemes,
        inputs=audio_input,
        outputs=[phoneme_output, audio_info_output],
        api_name=False,
    )

    # Check the access code on button click, or when Enter is pressed in the box.
    login_btn.click(
        fn=check_password,
        inputs=password_input,
        outputs=[login_page, main_app],
    )
    password_input.submit(
        fn=check_password,
        inputs=password_input,
        outputs=[login_page, main_app],
    )
# Entry point when run as a script. Gradio's built-in authentication is
# deliberately not used; access is gated by the in-app access-code screen.
if __name__ == "__main__":
    launch_options = dict(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        share=False,  # no public gradio.live tunnel
    )
    demo.launch(**launch_options)