Spaces:
Running
Running
Initialized project directories, added requirements, and implemented core Gradio UI with lazy-loaded Wav2Vec2 inference
Browse files
- app.py +4 -9
- app/asr_model.py +47 -5
- app/audio_processing.py +17 -3
- app/history.py +48 -4
- app/language_detection.py +16 -5
app.py
CHANGED
|
@@ -1,18 +1,18 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from app.asr_model import load_model, transcribe_audio
|
| 3 |
-
from app.language_detection import
|
| 4 |
from app.history import save_to_history, export_history
|
| 5 |
|
| 6 |
def process_audio(audio_path):
|
| 7 |
if audio_path is None:
|
| 8 |
return "No audio uploaded.", "Unknown"
|
| 9 |
|
| 10 |
-
# Optional: Detect Language
|
| 11 |
-
lang = detect_language(audio_path)
|
| 12 |
-
|
| 13 |
# Transcribe Speech
|
| 14 |
transcript = transcribe_audio(audio_path)
|
| 15 |
|
|
|
|
|
|
|
|
|
|
| 16 |
# Save History
|
| 17 |
save_to_history(audio_path, transcript, lang)
|
| 18 |
|
|
@@ -42,10 +42,5 @@ def create_ui():
|
|
| 42 |
return demo
|
| 43 |
|
| 44 |
if __name__ == "__main__":
|
| 45 |
-
# Pre-load model on start
|
| 46 |
-
print("Loading model...")
|
| 47 |
-
load_model()
|
| 48 |
-
print("Model loaded. Starting UI...")
|
| 49 |
-
|
| 50 |
demo = create_ui()
|
| 51 |
demo.launch()
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from app.asr_model import load_model, transcribe_audio
|
| 3 |
+
from app.language_detection import detect_language_from_text
|
| 4 |
from app.history import save_to_history, export_history
|
| 5 |
|
| 6 |
def process_audio(audio_path):
|
| 7 |
if audio_path is None:
|
| 8 |
return "No audio uploaded.", "Unknown"
|
| 9 |
|
|
|
|
|
|
|
|
|
|
| 10 |
# Transcribe Speech
|
| 11 |
transcript = transcribe_audio(audio_path)
|
| 12 |
|
| 13 |
+
# Detect Language from transcript
|
| 14 |
+
lang = detect_language_from_text(transcript)
|
| 15 |
+
|
| 16 |
# Save History
|
| 17 |
save_to_history(audio_path, transcript, lang)
|
| 18 |
|
|
|
|
| 42 |
return demo
|
| 43 |
|
| 44 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
demo = create_ui()
|
| 46 |
demo.launch()
|
app/asr_model.py
CHANGED
|
@@ -1,15 +1,57 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
Loads the Hugging Face Wav2Vec model and processor.
|
| 6 |
-
|
|
|
|
|
|
|
| 7 |
"""
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
def transcribe_audio(audio_filepath: str) -> str:
|
| 11 |
"""
|
| 12 |
Takes an audio filepath, processes it, and runs it through the Wav2Vec model
|
| 13 |
to return a text transcription.
|
| 14 |
"""
|
| 15 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
|
| 3 |
+
from app.audio_processing import load_and_resample
|
| 4 |
|
| 5 |
+
# Global variables to hold the model and processor
|
| 6 |
+
processor = None
|
| 7 |
+
model = None
|
| 8 |
+
|
| 9 |
+
def load_model(model_name: str = "facebook/wav2vec2-base-960h"):
    """
    Loads the Hugging Face Wav2Vec model and processor.

    The loaded objects are stored in the module-level ``processor`` and
    ``model`` globals; nothing is returned.

    Defaulting to English base model. For multilingual, consider models like:
    - 'facebook/mms-1b-all'
    - 'jonatasgrosman/wav2vec2-large-xlsr-53-english' (or other languages)

    Args:
        model_name (str): Model identifier handed to ``from_pretrained`` for
            both the processor and the CTC model.
    """
    global processor, model
    print(f"Loading model {model_name}...")

    # Load into locals first so the globals are only published once both
    # pieces loaded successfully (no half-initialised module state).
    loaded_processor = Wav2Vec2Processor.from_pretrained(model_name)
    loaded_model = Wav2Vec2ForCTC.from_pretrained(model_name)

    # Move to GPU if available (MPS for Apple Silicon).
    # `torch.backends.mps` is not present in every PyTorch build, so look it
    # up defensively instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        loaded_model.to("mps")
    elif torch.cuda.is_available():
        loaded_model.to("cuda")

    processor, model = loaded_processor, loaded_model
|
| 26 |
|
| 27 |
def transcribe_audio(audio_filepath: str) -> str:
    """
    Takes an audio filepath, processes it, and runs it through the Wav2Vec model
    to return a text transcription.

    The model and processor are loaded on first use if ``load_model`` has not
    been called yet.

    Args:
        audio_filepath (str): Path to the audio file to transcribe.

    Returns:
        str: The lower-cased transcription. If anything fails while loading
        the audio or running the model, a string starting with
        "Error during transcription: " is returned instead of raising.
    """
    if model is None or processor is None:
        load_model()

    sample_rate = 16000

    try:
        # 1. Load and resample audio to 16kHz
        speech = load_and_resample(audio_filepath, target_sr=sample_rate)

        # An empty recording cannot be transcribed; report it directly
        # instead of surfacing a low-level error from the model.
        if len(speech) == 0:
            return "Error during transcription: audio contains no samples"

        # 2. Prepare inputs
        inputs = processor(speech, sampling_rate=sample_rate, return_tensors="pt", padding=True)

        # Move inputs to the same device as model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # 3. Perform inference (no gradients needed)
        with torch.no_grad():
            logits = model(**inputs).logits

        # 4. Decode the output: most likely token per frame, then decode
        predicted_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(predicted_ids)[0]

        return transcription.lower()
    except Exception as e:
        # Deliberate catch-all: the caller expects a string back. Log the
        # failure so it is not only visible as transcript text.
        print(f"Transcription failed for {audio_filepath}: {e!r}")
        return f"Error during transcription: {str(e)}"
|
| 57 |
+
|
app/audio_processing.py
CHANGED
|
@@ -1,7 +1,21 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
|
| 3 |
-
def load_and_resample(audio_filepath: str, target_sr: int = 16000):
|
| 4 |
"""
|
| 5 |
Loads an audio file and resamples it to the target sample rate (default 16kHz for Wav2Vec).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
"""
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import librosa
|
| 2 |
+
import numpy as np
|
| 3 |
|
| 4 |
+
def load_and_resample(audio_filepath: str, target_sr: int = 16000) -> np.ndarray:
    """
    Loads an audio file and resamples it to the target sample rate (default 16kHz for Wav2Vec).

    Args:
        audio_filepath (str): Path to the audio file.
        target_sr (int): The sample rate required by the model.

    Returns:
        np.ndarray: The audio time series.

    Raises:
        RuntimeError: If no path is given or the file cannot be loaded. When
            loading fails, the underlying exception is attached as
            ``__cause__``.
    """
    # Fail early with a clear message instead of handing an empty path
    # to the decoder.
    if not audio_filepath:
        raise RuntimeError(
            f"Error processing audio file {audio_filepath}: no file path provided"
        )

    try:
        # librosa automatically resamples if sr is provided
        speech, _ = librosa.load(audio_filepath, sr=target_sr)
    except Exception as e:
        # Chain the original error so the root cause stays in the traceback.
        raise RuntimeError(f"Error processing audio file {audio_filepath}: {e}") from e
    return speech
|
| 21 |
+
|
app/history.py
CHANGED
|
@@ -1,13 +1,57 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
def save_to_history(audio_filepath: str, transcript: str, language: str):
|
| 4 |
"""
|
| 5 |
-
Saves the transcription data to a local JSON
|
| 6 |
"""
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def export_history(format: str = "csv"):
|
| 10 |
"""
|
| 11 |
Exports the saved history into a downloadable format.
|
|
|
|
| 12 |
"""
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import datetime
|
| 4 |
+
import csv
|
| 5 |
+
|
| 6 |
+
HISTORY_FILE = "data/history.json"

# Keys of every history entry; also the column order of the CSV export.
_HISTORY_FIELDS = ("timestamp", "audio_file", "language", "transcript")


def _load_history_entries() -> list:
    """
    Reads HISTORY_FILE and returns its entries as a list.

    Raises:
        OSError: If the file cannot be opened or read.
        ValueError: If the content is not valid JSON, or its top-level value
            is not a list.
    """
    with open(HISTORY_FILE, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(
            f"expected a JSON list in {HISTORY_FILE}, got {type(data).__name__}"
        )
    return data


def save_to_history(audio_filepath: str, transcript: str, language: str):
    """
    Saves the transcription data to a local JSON file in the data/ directory.

    The entry is appended to the list already stored in HISTORY_FILE. If the
    existing file cannot be read as a JSON list, it is moved aside to
    ``HISTORY_FILE + ".bak"`` (best effort) and a new history is started, so
    unreadable data is not silently overwritten.

    Args:
        audio_filepath (str): Path of the transcribed audio; only its base
            name is stored.
        transcript (str): The transcription text.
        language (str): The language label stored with the entry.
    """
    # dirname is "" when HISTORY_FILE has no directory part; makedirs("")
    # would raise, so only create a directory when there is one.
    history_dir = os.path.dirname(HISTORY_FILE)
    if history_dir:
        os.makedirs(history_dir, exist_ok=True)

    entry = {
        "timestamp": datetime.datetime.now().isoformat(),
        "audio_file": os.path.basename(audio_filepath),
        "language": language,
        "transcript": transcript,
    }

    history = []
    if os.path.exists(HISTORY_FILE):
        try:
            history = _load_history_entries()
        except (OSError, ValueError) as e:
            # Keep saving best-effort, but preserve the unreadable file.
            backup_path = HISTORY_FILE + ".bak"
            try:
                os.replace(HISTORY_FILE, backup_path)
                print(
                    f"History file unreadable ({e}); moved it to {backup_path} "
                    "and starting a new history."
                )
            except OSError:
                print(f"History file unreadable ({e}); starting a new history.")

    history.append(entry)

    # Write to a temporary file and swap it in, so an interrupted write
    # cannot leave a truncated history file behind.
    tmp_path = HISTORY_FILE + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(history, f, indent=4)
    os.replace(tmp_path, HISTORY_FILE)


def export_history(format: str = "csv"):
    """
    Exports the saved history into a downloadable format.
    Returns the path to the exported file.

    Args:
        format (str): "json" (case-insensitive) writes a JSON export; any
            other value writes CSV, which keeps the previous behaviour where
            this argument was ignored.

    Returns:
        The path of the export file, written next to HISTORY_FILE, or None
        if there is no history file or the export fails.
    """
    if not os.path.exists(HISTORY_FILE):
        return None

    extension = "json" if str(format).lower() == "json" else "csv"
    stamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    export_dir = os.path.dirname(HISTORY_FILE) or "."
    export_path = os.path.join(export_dir, f"export_{stamp}.{extension}")

    try:
        history = _load_history_entries()

        if extension == "json":
            with open(export_path, "w", encoding="utf-8") as f:
                json.dump(history, f, indent=4)
        else:
            with open(export_path, "w", newline='', encoding="utf-8") as f:
                # extrasaction="ignore": an entry carrying an unexpected key
                # must not abort the whole export.
                writer = csv.DictWriter(
                    f, fieldnames=list(_HISTORY_FIELDS), extrasaction="ignore"
                )
                writer.writeheader()
                writer.writerows(row for row in history if isinstance(row, dict))

        return export_path
    except (OSError, ValueError, csv.Error) as e:
        print(f"Failed to export history: {e}")
        return None
|
app/language_detection.py
CHANGED
|
@@ -1,8 +1,19 @@
|
|
| 1 |
-
|
| 2 |
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
-
|
| 6 |
-
|
| 7 |
"""
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langdetect import detect, DetectorFactory
|
| 2 |
|
| 3 |
+
# Ensure consistent results
|
| 4 |
+
DetectorFactory.seed = 0
|
| 5 |
+
|
| 6 |
+
def detect_language_from_text(text: str) -> str:
    """
    Detects language based on the transcribed text.
    Returns the ISO 639-1 language code (e.g., 'en', 'es', 'fr').

    Args:
        text (str): The transcript to analyse.

    Returns:
        str: The code reported by ``detect``, or "Unknown" when the input is
        not a string, has fewer than two non-whitespace-trimmed characters,
        or detection fails.
    """
    # Check the type before calling str methods, so a non-string value
    # yields "Unknown" instead of raising outside the try block.
    if not isinstance(text, str) or len(text.strip()) < 2:
        return "Unknown"

    try:
        return detect(text)
    except Exception:
        # Detection is best effort; any failure maps to "Unknown".
        return "Unknown"
|
| 19 |
+
|