# EngTrainer / app.py
# Author: Anupam007 — commit 48e44ca ("Update app.py", verified)
import os
import gradio as gr
import torch
import numpy as np
import tempfile
import librosa
import re
from gtts import gTTS
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
import whisper
import sentencepiece # Ensure SentencePiece is imported
# Pick the compute device once at import time: CUDA when torch can see a GPU,
# otherwise fall back to the CPU. The choice is printed for visibility in logs.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")
# Load models with error handling
try:
whisper_model = whisper.load_model("small")
except Exception as e:
print(f"Failed to load Whisper model: {e}")
whisper_model = None
# Load the T5 tokenizer and model ("t5-base") and move the model to `device`.
# This uses the same try/except pattern as the other model loads, so a download
# or load failure degrades both names to None instead of aborting the app at
# import time.
# NOTE(review): neither object is referenced by any function visible in this
# file. Grammar correction in process_audio is still a placeholder, so this
# load may be removable until that is implemented.
try:
    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    t5_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)
except Exception as e:
    print(f"Failed to load T5 model: {e}")
    t5_tokenizer = None
    t5_model = None
# Load an SST-2 sentiment classifier via the transformers pipeline. On failure
# the error is printed and ``sentiment_analyzer`` stays None.
# NOTE(review): no function visible in this file uses it yet.
sentiment_analyzer = None
try:
    sentiment_analyzer = pipeline(
        "text-classification",
        model="distilbert-base-uncased-finetuned-sst-2-english",
        # pipeline takes a device index: 0 = first GPU, -1 = CPU.
        device=0 if device == "cuda" else -1,
    )
except Exception as e:
    print(f"Failed to load sentiment analyzer: {e}")
def speech_to_text(audio_path):
    """Transcribe the audio file at *audio_path* with the module-level Whisper model.

    Returns the transcript with surrounding whitespace stripped. Failures are
    reported as strings rather than raised:

    - "Whisper model is not loaded." when the model is missing (falsy).
    - "Speech recognition error: <exc>" when transcription or reading its
      "text" field fails.
    """
    model = whisper_model
    if model:
        try:
            transcript = model.transcribe(audio_path)["text"]
            return transcript.strip()
        except Exception as exc:
            return f"Speech recognition error: {exc}"
    return "Whisper model is not loaded."
def process_audio(audio_path):
    """Transcribe a recorded clip and package the results for the UI.

    Returns a 6-tuple ``(original_text, corrected_text, "", "", "", None)``.
    ``corrected_text`` is currently the transcript unchanged, because grammar
    correction is not implemented. The last four slots are always
    ``"", "", "", None``.

    If *audio_path* is empty/None or does not exist, the first slot holds
    "Error: No valid audio file provided.". If transcription raises, the first
    slot holds "Processing error: <exc>". In both error cases the other slots
    are empty/None.
    """
    if not (audio_path and os.path.exists(audio_path)):
        return "Error: No valid audio file provided.", "", "", "", "", None
    try:
        transcript = speech_to_text(audio_path)
    except Exception as exc:
        return f"Processing error: {exc}", "", "", "", "", None
    # Grammar correction is not implemented yet; pass the transcript through.
    corrected = transcript
    return transcript, corrected, "", "", "", None
def create_interface():
    """Build the Gradio Blocks UI for recording speech and showing its transcript.

    ``process_audio`` returns a 6-tuple, but this UI has a single output
    textbox. With only one output component, Gradio treats the handler's whole
    return value as that component's value, so wiring ``process_audio``
    directly made the textbox display the tuple's repr. The click handler
    therefore forwards only the first element: the transcript, or an error
    message.

    Returns:
        gr.Blocks: the assembled app (not yet launched).
    """

    def _recognized_text(audio_path):
        # Keep only the recognized-text / error-message slot of the result tuple.
        return process_audio(audio_path)[0]

    with gr.Blocks() as app:
        audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Record your speech")
        output_text = gr.Textbox(label="Recognized Text")
        submit_btn = gr.Button("Analyze Speech")
        submit_btn.click(_recognized_text, inputs=[audio_input], outputs=[output_text])
    return app
if __name__ == "__main__":
    # Build the UI, then serve it on all interfaces. The port comes from $PORT
    # and defaults to 7860.
    app = create_interface()
    port = int(os.getenv("PORT", 7860))
    app.launch(server_port=port, server_name="0.0.0.0")