Spaces:
Running
Running
app.py file name changed to gradio_ui.py which shows history tab and transcribe button
Browse files- app.py +0 -46
- app/asr_model.py +10 -5
app.py
DELETED
|
@@ -1,46 +0,0 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
from app.asr_model import load_model, transcribe_audio
|
| 3 |
-
from app.language_detection import detect_language_from_text
|
| 4 |
-
from app.history import save_to_history, export_history
|
| 5 |
-
|
| 6 |
-
def process_audio(audio_path):
|
| 7 |
-
if audio_path is None:
|
| 8 |
-
return "No audio uploaded.", "Unknown"
|
| 9 |
-
|
| 10 |
-
# Transcribe Speech
|
| 11 |
-
transcript = transcribe_audio(audio_path)
|
| 12 |
-
|
| 13 |
-
# Detect Language from transcript
|
| 14 |
-
lang = detect_language_from_text(transcript)
|
| 15 |
-
|
| 16 |
-
# Save History
|
| 17 |
-
save_to_history(audio_path, transcript, lang)
|
| 18 |
-
|
| 19 |
-
return transcript, lang
|
| 20 |
-
|
| 21 |
-
def create_ui():
|
| 22 |
-
with gr.Blocks(title="Multilingual ASR") as demo:
|
| 23 |
-
gr.Markdown("# Multilingual Automatic Speech Recognition")
|
| 24 |
-
gr.Markdown("Upload an audio file to get a text transcription using Wav2Vec.")
|
| 25 |
-
|
| 26 |
-
with gr.Row():
|
| 27 |
-
with gr.Column():
|
| 28 |
-
audio_input = gr.Audio(type="filepath", label="Upload Audio")
|
| 29 |
-
transcribe_btn = gr.Button("Transcribe")
|
| 30 |
-
|
| 31 |
-
with gr.Column():
|
| 32 |
-
lang_output = gr.Textbox(label="Detected Language")
|
| 33 |
-
transcript_output = gr.Textbox(label="Transcription", lines=10)
|
| 34 |
-
download_btn = gr.File(label="Download Transcript (Coming Soon)")
|
| 35 |
-
|
| 36 |
-
transcribe_btn.click(
|
| 37 |
-
fn=process_audio,
|
| 38 |
-
inputs=audio_input,
|
| 39 |
-
outputs=[transcript_output, lang_output]
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
return demo
|
| 43 |
-
|
| 44 |
-
if __name__ == "__main__":
|
| 45 |
-
demo = create_ui()
|
| 46 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app/asr_model.py
CHANGED
|
@@ -18,11 +18,16 @@ def load_model(model_name: str = "facebook/wav2vec2-base-960h"):
|
|
| 18 |
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
| 19 |
model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
| 20 |
|
| 21 |
-
#
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def transcribe_audio(audio_filepath: str) -> str:
|
| 28 |
"""
|
|
|
|
| 18 |
processor = Wav2Vec2Processor.from_pretrained(model_name)
|
| 19 |
model = Wav2Vec2ForCTC.from_pretrained(model_name)
|
| 20 |
|
| 21 |
+
# Temporary fix: Force CPU usage.
|
| 22 |
+
# The MPS (Apple Silicon GPU) backend in PyTorch currently has known bugs with Wav2Vec2
|
| 23 |
+
# that can cause the forward pass to freeze indefinitely.
|
| 24 |
+
device = "cpu"
|
| 25 |
+
# if torch.backends.mps.is_available():
|
| 26 |
+
# device = "mps"
|
| 27 |
+
# elif torch.cuda.is_available():
|
| 28 |
+
# device = "cuda"
|
| 29 |
+
|
| 30 |
+
model.to(device)
|
| 31 |
|
| 32 |
def transcribe_audio(audio_filepath: str) -> str:
|
| 33 |
"""
|