Noumida committed on
Commit
dca9a76
·
verified ·
1 Parent(s): 2fa52c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -9,21 +9,15 @@ DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # --- Model Loading ---
12
- # NOTE: If running on a Space with a HF_TOKEN secret,
13
- # transformers will automatically use it for gated models.
14
-
15
- # ASR Model (IndicConformer)
16
  print("Loading ASR model (IndicConformer)...")
17
  asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
18
  asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
19
  asr_model.eval()
20
  print("✅ ASR Model loaded.")
21
 
22
- # Language Identification (LID) Model - Using your specified model
23
  print("\nLoading Language ID model (MMS-LID-1024)...")
24
  lid_model_id = "facebook/mms-lid-1024"
25
  lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
26
- # Load the model with its audio classification head
27
  lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
28
  lid_model.eval()
29
  print("✅ Language ID Model loaded.")
@@ -36,7 +30,8 @@ LID_TO_ASR_LANG_MAP = {
36
  "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
37
  "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
38
  "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
39
- "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
 
40
  }
41
 
42
  ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
@@ -54,26 +49,26 @@ def transcribe_audio_with_lid(audio_path):
54
  return f"Error loading audio: {e}", "", ""
55
 
56
  try:
57
- # 1. --- Language Identification ---
58
  inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(device)
59
  with torch.no_grad():
60
  outputs = lid_model(**inputs)
61
-
62
- # CORRECTED: Access logits as the first element of the output tuple
63
  logits = outputs[0]
64
  predicted_lid_id = logits.argmax(-1).item()
65
  detected_lid_code = lid_model.config.id2label[predicted_lid_id]
66
 
67
- # 2. --- Map to ASR Language Code ---
68
  asr_lang_code = LID_TO_ASR_LANG_MAP.get(detected_lid_code)
69
 
70
  if not asr_lang_code:
71
- detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model."
72
- return detected_lang_str, "N/A", "N/A"
 
 
 
 
73
 
74
  detected_lang_str = f"Detected Language: {ASR_CODE_TO_NAME.get(asr_lang_code, 'Unknown')}"
75
 
76
- # 3. --- Transcription using the detected language ---
77
  with torch.no_grad():
78
  transcription_ctc = asr_model(waveform_16k.to(device), asr_lang_code, "ctc")
79
  transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt")
@@ -83,6 +78,7 @@ def transcribe_audio_with_lid(audio_path):
83
 
84
  return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()
85
 
 
86
  # --- Gradio UI ---
87
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
88
  gr.Markdown(f"## {DESCRIPTION}")
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # --- Model Loading ---
 
 
 
 
12
  print("Loading ASR model (IndicConformer)...")
13
  asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
14
  asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
15
  asr_model.eval()
16
  print("✅ ASR Model loaded.")
17
 
 
18
  print("\nLoading Language ID model (MMS-LID-1024)...")
19
  lid_model_id = "facebook/mms-lid-1024"
20
  lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
 
21
  lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
22
  lid_model.eval()
23
  print("✅ Language ID Model loaded.")
 
30
  "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
31
  "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
32
  "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
33
+ "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur",
34
+ "pan": "pa" # <-- ADDED THIS FIX FOR PUNJABI
35
  }
36
 
37
  ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
 
49
  return f"Error loading audio: {e}", "", ""
50
 
51
  try:
 
52
  inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(device)
53
  with torch.no_grad():
54
  outputs = lid_model(**inputs)
55
+
 
56
  logits = outputs[0]
57
  predicted_lid_id = logits.argmax(-1).item()
58
  detected_lid_code = lid_model.config.id2label[predicted_lid_id]
59
 
 
60
  asr_lang_code = LID_TO_ASR_LANG_MAP.get(detected_lid_code)
61
 
62
  if not asr_lang_code:
63
+ # Fallback for simple codes like 'pan' from other LID models
64
+ if detected_lid_code in LID_TO_ASR_LANG_MAP:
65
+ asr_lang_code = LID_TO_ASR_LANG_MAP[detected_lid_code]
66
+ else:
67
+ detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model."
68
+ return detected_lang_str, "N/A", "N/A"
69
 
70
  detected_lang_str = f"Detected Language: {ASR_CODE_TO_NAME.get(asr_lang_code, 'Unknown')}"
71
 
 
72
  with torch.no_grad():
73
  transcription_ctc = asr_model(waveform_16k.to(device), asr_lang_code, "ctc")
74
  transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt")
 
78
 
79
  return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()
80
 
81
+
82
  # --- Gradio UI ---
83
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
84
  gr.Markdown(f"## {DESCRIPTION}")