Noumida committed on
Commit
3627a6f
·
verified ·
1 Parent(s): 354de8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -41
app.py CHANGED
@@ -3,26 +3,29 @@ import torch
3
  import torchaudio
4
  import gradio as gr
5
  import spaces
6
- from transformers import AutoModel, Wav2Vec2Processor, Wav2Vec2ForCTC
 
7
 
8
  DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
  # --- Model Loading ---
 
 
12
 
13
- # ASR Model (IndicConformer) - This is a custom model, so AutoModel is appropriate
14
  print("Loading ASR model (IndicConformer)...")
15
  asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
16
  asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
17
  asr_model.eval()
18
  print("✅ ASR Model loaded.")
19
 
20
- # Language Identification (LID) Model - Standard Wav2Vec2 architecture
21
- print("\nLoading Language ID model (MMS-LID)...")
22
- lid_model_id = "facebook/wav2vec2-base-960h-lid" # Using the official LID fine-tune
23
- # The processor bundles the feature extractor and the tokenizer/decoder for this model
24
- lid_processor = Wav2Vec2Processor.from_pretrained(lid_model_id)
25
- lid_model = Wav2Vec2ForCTC.from_pretrained(lid_model_id).to(device)
26
  lid_model.eval()
27
  print("✅ Language ID Model loaded.")
28
 
@@ -30,36 +33,16 @@ print("✅ Language ID Model loaded.")
30
  # --- Language Mappings ---
31
  # Maps the LID model's output code to the ASR model's code
32
  LID_TO_ASR_LANG_MAP = {
33
- "asm": "as", "ben": "bn", "brx": "br", "doi": "doi", "guj": "gu", "hin": "hi",
34
- "kan": "kn", "kas": "ks", "kok": "kok", "mai": "mai", "mal": "ml", "mni": "mni",
35
- "mar": "mr", "nep": "ne", "ori": "or", "pan": "pa", "san": "sa", "sat": "sat",
36
- "snd": "sd", "tam": "ta", "tel": "te", "urd": "ur",
37
- # Adding English as the LID model supports it
38
- "eng": "en"
39
  }
40
 
41
  # Maps the ASR model's code back to a full name for display
42
- ASR_CODE_TO_NAME = {
43
- "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati",
44
- "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili",
45
- "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia",
46
- "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil",
47
- "te": "Telugu", "ur": "Urdu", "en": "English"
48
- }
49
-
50
-
51
- # --- Core Logic Functions ---
52
-
53
- def identify_language(waveform_16k):
54
- """Identifies the language from an audio waveform using the LID model."""
55
- input_values = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").input_values
56
- with torch.no_grad():
57
- logits = lid_model(input_values.to(device)).logits
58
-
59
- predicted_ids = torch.argmax(logits, dim=-1)
60
- # The 'decode' function for this specific LID model gives the language code
61
- language_code = lid_processor.decode(predicted_ids)
62
- return language_code.strip()
63
 
64
 
65
  @spaces.GPU
@@ -68,6 +51,7 @@ def transcribe_audio_with_lid(audio_path):
68
  return "Please provide an audio file.", "", ""
69
 
70
  try:
 
71
  waveform, sr = torchaudio.load(audio_path)
72
  waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
73
  except Exception as e:
@@ -75,14 +59,18 @@ def transcribe_audio_with_lid(audio_path):
75
 
76
  try:
77
  # 1. --- Language Identification ---
78
- # The LID model's output is a simple language code (e.g., "hin" for Hindi)
79
- detected_lid_code = identify_language(waveform_16k)
 
 
 
 
 
80
 
81
  # 2. --- Map to ASR Language Code ---
82
- # Note: We are simplifying the mapping as the new LID model gives simpler codes
83
- asr_lang_code = detected_lid_code.lower()
84
 
85
- if asr_lang_code not in ASR_CODE_TO_NAME:
86
  detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model."
87
  return detected_lang_str, "N/A", "N/A"
88
 
@@ -101,7 +89,7 @@ def transcribe_audio_with_lid(audio_path):
101
  # --- Gradio UI ---
102
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
103
  gr.Markdown(f"## {DESCRIPTION}")
104
- gr.Markdown("Upload or record audio in any of the supported languages. The app will automatically detect the language and provide the transcription.")
105
 
106
  with gr.Row():
107
  with gr.Column(scale=1):
 
3
  import torchaudio
4
  import gradio as gr
5
  import spaces
6
+ # Import the correct AutoModel class for the task
7
+ from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
8
 
9
  DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
  # --- Model Loading ---
13
+ # NOTE: If running on a Space with a HF_TOKEN secret,
14
+ # transformers will automatically use it for gated models.
15
 
16
+ # ASR Model (IndicConformer)
17
  print("Loading ASR model (IndicConformer)...")
18
  asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
19
  asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
20
  asr_model.eval()
21
  print("✅ ASR Model loaded.")
22
 
23
+ # Language Identification (LID) Model
24
+ print("\nLoading Language ID model (MMS-LID-1024)...")
25
+ lid_model_id = "facebook/mms-lid-1024"
26
+ lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
27
+ # Load the model with its audio classification head to get logits
28
+ lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device) # <-- THIS LINE IS UPDATED
29
  lid_model.eval()
30
  print("✅ Language ID Model loaded.")
31
 
 
33
  # --- Language Mappings ---
34
  # Maps the LID model's output code to the ASR model's code
35
  LID_TO_ASR_LANG_MAP = {
36
+ "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
37
+ "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
38
+ "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
39
+ "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
40
+ "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
41
+ "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
42
  }
43
 
44
  # Maps the ASR model's code back to a full name for display
45
+ ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
 
48
  @spaces.GPU
 
51
  return "Please provide an audio file.", "", ""
52
 
53
  try:
54
+ # Load and preprocess audio
55
  waveform, sr = torchaudio.load(audio_path)
56
  waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
57
  except Exception as e:
 
59
 
60
  try:
61
  # 1. --- Language Identification ---
62
+ inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(device)
63
+ with torch.no_grad():
64
+ outputs = lid_model(**inputs)
65
+
66
+ # This will now work because the output object has the .logits attribute
67
+ predicted_lid_id = outputs.logits.argmax(-1).item()
68
+ detected_lid_code = lid_model.config.id2label[predicted_lid_id]
69
 
70
  # 2. --- Map to ASR Language Code ---
71
+ asr_lang_code = LID_TO_ASR_LANG_MAP.get(detected_lid_code)
 
72
 
73
+ if not asr_lang_code:
74
  detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model."
75
  return detected_lang_str, "N/A", "N/A"
76
 
 
89
  # --- Gradio UI ---
90
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
91
  gr.Markdown(f"## {DESCRIPTION}")
92
+ gr.Markdown("Upload or record audio in any of the 22 supported Indian languages. The app will automatically detect the language and provide the transcription.")
93
 
94
  with gr.Row():
95
  with gr.Column(scale=1):