Noumida committed on
Commit
391a015
·
verified ·
1 Parent(s): ff42fba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -48
app.py CHANGED
@@ -3,25 +3,38 @@ import torch
3
  import torchaudio
4
  import gradio as gr
5
  import spaces
6
- from transformers import AutoModel, AutoProcessor, Wav2Vec2ForCTC
7
- import re
 
8
 
9
  DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
 
 
 
 
 
 
 
 
12
 
 
 
 
13
  print("Loading ASR model (IndicConformer)...")
14
  asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
15
  asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
16
  asr_model.eval()
17
- print(" ASR Model loaded.")
18
 
19
- print("\nLoading Language ID model (MMS-LID)...")
20
- lid_model_id = "facebook/mms-lid"
 
21
  lid_processor = AutoProcessor.from_pretrained(lid_model_id)
22
  lid_model = AutoModel.from_pretrained(lid_model_id).to(device)
23
  lid_model.eval()
24
- print(" Language ID Model loaded.")
25
 
26
 
27
  # --- Language Mappings ---
@@ -37,45 +50,8 @@ LID_TO_ASR_LANG_MAP = {
37
 
38
  # Maps the ASR model's code back to a full name for display
39
  ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
40
- import torch
41
- import torchaudio
42
- import gradio as gr
43
- import spaces
44
- from transformers import AutoModel, AutoProcessor, Wav2Vec2ForCTC
45
- import re
46
-
47
- DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
48
- device = "cuda" if torch.cuda.is_available() else "cpu"
49
-
50
- # --- ASR Model (The one we used before) ---
51
- print("Loading ASR model (IndicConformer)...")
52
- asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
53
- asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
54
- asr_model.eval()
55
- print(" ASR Model loaded.")
56
-
57
- # --- Language Identification (LID) Model ---
58
- print("\nLoading Language ID model (MMS-LID)...")
59
- lid_model_id = "facebook/mms-lid"
60
- lid_processor = AutoProcessor.from_pretrained(lid_model_id)
61
- lid_model = AutoModel.from_pretrained(lid_model_id).to(device)
62
- lid_model.eval()
63
- print(" Language ID Model loaded.")
64
 
65
 
66
- # --- Language Mappings ---
67
- # Maps the LID model's output code to the ASR model's code
68
- LID_TO_ASR_LANG_MAP = {
69
- "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
70
- "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
71
- "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
72
- "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
73
- "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
74
- "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
75
- }
76
-
77
- # Maps the ASR model's code back to a full name for display
78
- ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
79
  @spaces.GPU
80
  def transcribe_audio_with_lid(audio_path):
81
  if not audio_path:
@@ -84,7 +60,6 @@ def transcribe_audio_with_lid(audio_path):
84
  try:
85
  # Load and preprocess audio
86
  waveform, sr = torchaudio.load(audio_path)
87
- # Resample for both models
88
  waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
89
  except Exception as e:
90
  return f"Error loading audio: {e}", "", ""
@@ -95,9 +70,7 @@ def transcribe_audio_with_lid(audio_path):
95
  with torch.no_grad():
96
  outputs = lid_model(**inputs)
97
 
98
- # Get the top predicted language ID from the LID model
99
  predicted_lid_id = outputs.logits.argmax(-1).item()
100
- # The model.config.id2label gives us the language code like "hin_Deva"
101
  detected_lid_code = lid_model.config.id2label[predicted_lid_id]
102
 
103
  # 2. --- Map to ASR Language Code ---
@@ -111,7 +84,6 @@ def transcribe_audio_with_lid(audio_path):
111
 
112
  # 3. --- Transcription using the detected language ---
113
  with torch.no_grad():
114
- # Use the ASR model with the correctly identified language code
115
  transcription_ctc = asr_model(waveform_16k.to(device), asr_lang_code, "ctc")
116
  transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt")
117
 
@@ -120,7 +92,7 @@ def transcribe_audio_with_lid(audio_path):
120
 
121
  return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()
122
 
123
- # Gradio UI (no major changes needed here)
124
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
125
  gr.Markdown(f"## {DESCRIPTION}")
126
  gr.Markdown("Upload or record audio in any of the 22 supported Indian languages. The app will automatically detect the language and provide the transcription.")
 
3
  import torchaudio
4
  import gradio as gr
5
  import spaces
6
+ from transformers import AutoModel, AutoProcessor
7
+ from huggingface_hub import login
8
+ from google.colab import userdata # Or use os.environ if not in Colab
9
 
10
  DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
 
13
+ # --- Authentication Step ---
14
+ try:
15
+ # Fetches the token from Colab secrets (google.colab.userdata is Colab-only; elsewhere, e.g. on Hugging Face Spaces, read it from os.environ)
16
+ HF_TOKEN = userdata.get('HF_TOKEN')
17
+ login(token=HF_TOKEN)
18
+ print("✅ Successfully logged into Hugging Face Hub.")
19
+ except Exception as e:
20
+ print(f"⚠️ Could not log into Hugging Face Hub. Please ensure HF_TOKEN is set correctly. Error: {e}")
21
 
22
+ # --- Model Loading ---
23
+
24
+ # ASR Model (IndicConformer)
25
  print("Loading ASR model (IndicConformer)...")
26
  asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
27
  asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
28
  asr_model.eval()
29
+ print(" ASR Model loaded.")
30
 
31
+ # Language Identification (LID) Model - Updated to the user-specified version
32
+ print("\nLoading Language ID model (MMS-LID-1024)...")
33
+ lid_model_id = "facebook/mms-lid-1024" # <-- THIS LINE HAS BEEN UPDATED
34
  lid_processor = AutoProcessor.from_pretrained(lid_model_id)
35
  lid_model = AutoModel.from_pretrained(lid_model_id).to(device)
36
  lid_model.eval()
37
+ print(" Language ID Model loaded.")
38
 
39
 
40
  # --- Language Mappings ---
 
50
 
51
  # Maps the ASR model's code back to a full name for display
52
  ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @spaces.GPU
56
  def transcribe_audio_with_lid(audio_path):
57
  if not audio_path:
 
60
  try:
61
  # Load and preprocess audio
62
  waveform, sr = torchaudio.load(audio_path)
 
63
  waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
64
  except Exception as e:
65
  return f"Error loading audio: {e}", "", ""
 
70
  with torch.no_grad():
71
  outputs = lid_model(**inputs)
72
 
 
73
  predicted_lid_id = outputs.logits.argmax(-1).item()
 
74
  detected_lid_code = lid_model.config.id2label[predicted_lid_id]
75
 
76
  # 2. --- Map to ASR Language Code ---
 
84
 
85
  # 3. --- Transcription using the detected language ---
86
  with torch.no_grad():
 
87
  transcription_ctc = asr_model(waveform_16k.to(device), asr_lang_code, "ctc")
88
  transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt")
89
 
 
92
 
93
  return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()
94
 
95
+ # --- Gradio UI ---
96
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
97
  gr.Markdown(f"## {DESCRIPTION}")
98
  gr.Markdown("Upload or record audio in any of the 22 supported Indian languages. The app will automatically detect the language and provide the transcription.")