Update app.py
Browse files
app.py
CHANGED
|
@@ -3,7 +3,6 @@ import torch
|
|
| 3 |
import torchaudio
|
| 4 |
import gradio as gr
|
| 5 |
import spaces
|
| 6 |
-
# Import the correct AutoModel class for the task
|
| 7 |
from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
|
| 8 |
|
| 9 |
DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
|
|
@@ -20,18 +19,17 @@ asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(d
|
|
| 20 |
asr_model.eval()
|
| 21 |
print("✅ ASR Model loaded.")
|
| 22 |
|
| 23 |
-
# Language Identification (LID) Model
|
| 24 |
print("\nLoading Language ID model (MMS-LID-1024)...")
|
| 25 |
lid_model_id = "facebook/mms-lid-1024"
|
| 26 |
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
|
| 27 |
-
# Load the model with its audio classification head
|
| 28 |
-
lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
|
| 29 |
lid_model.eval()
|
| 30 |
print("✅ Language ID Model loaded.")
|
| 31 |
|
| 32 |
|
| 33 |
# --- Language Mappings ---
|
| 34 |
-
# Maps the LID model's output code to the ASR model's code
|
| 35 |
LID_TO_ASR_LANG_MAP = {
|
| 36 |
"asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
|
| 37 |
"guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
|
|
@@ -41,7 +39,6 @@ LID_TO_ASR_LANG_MAP = {
|
|
| 41 |
"tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
|
| 42 |
}
|
| 43 |
|
| 44 |
-
# Maps the ASR model's code back to a full name for display
|
| 45 |
ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
|
| 46 |
|
| 47 |
|
|
@@ -51,7 +48,6 @@ def transcribe_audio_with_lid(audio_path):
|
|
| 51 |
return "Please provide an audio file.", "", ""
|
| 52 |
|
| 53 |
try:
|
| 54 |
-
# Load and preprocess audio
|
| 55 |
waveform, sr = torchaudio.load(audio_path)
|
| 56 |
waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
|
| 57 |
except Exception as e:
|
|
@@ -63,8 +59,9 @@ def transcribe_audio_with_lid(audio_path):
|
|
| 63 |
with torch.no_grad():
|
| 64 |
outputs = lid_model(**inputs)
|
| 65 |
|
| 66 |
-
#
|
| 67 |
-
|
|
|
|
| 68 |
detected_lid_code = lid_model.config.id2label[predicted_lid_id]
|
| 69 |
|
| 70 |
# 2. --- Map to ASR Language Code ---
|
|
|
|
| 3 |
import torchaudio
|
| 4 |
import gradio as gr
|
| 5 |
import spaces
|
|
|
|
| 6 |
from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
|
| 7 |
|
| 8 |
DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
|
|
|
|
| 19 |
asr_model.eval()
|
| 20 |
print("✅ ASR Model loaded.")
|
| 21 |
|
| 22 |
+
# Language Identification (LID) model: facebook/mms-lid-1024
|
| 23 |
print("\nLoading Language ID model (MMS-LID-1024)...")
|
| 24 |
lid_model_id = "facebook/mms-lid-1024"
|
| 25 |
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
|
| 26 |
+
# Load the model with its audio classification head
|
| 27 |
+
lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
|
| 28 |
lid_model.eval()
|
| 29 |
print("✅ Language ID Model loaded.")
|
| 30 |
|
| 31 |
|
| 32 |
# --- Language Mappings ---
|
|
|
|
| 33 |
LID_TO_ASR_LANG_MAP = {
|
| 34 |
"asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
|
| 35 |
"guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
|
|
|
|
| 39 |
"tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
|
| 40 |
}
|
| 41 |
|
|
|
|
| 42 |
ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
|
| 43 |
|
| 44 |
|
|
|
|
| 48 |
return "Please provide an audio file.", "", ""
|
| 49 |
|
| 50 |
try:
|
|
|
|
| 51 |
waveform, sr = torchaudio.load(audio_path)
|
| 52 |
waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
|
| 53 |
except Exception as e:
|
|
|
|
| 59 |
with torch.no_grad():
|
| 60 |
outputs = lid_model(**inputs)
|
| 61 |
|
| 62 |
+
# Take the logits by position: they are the first element of the model output when no labels are passed
|
| 63 |
+
logits = outputs[0]
|
| 64 |
+
predicted_lid_id = logits.argmax(-1).item()
|
| 65 |
detected_lid_code = lid_model.config.id2label[predicted_lid_id]
|
| 66 |
|
| 67 |
# 2. --- Map to ASR Language Code ---
|