Noumida commited on
Commit
2fa52c3
·
verified ·
1 Parent(s): 3627a6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -9
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  import torchaudio
4
  import gradio as gr
5
  import spaces
6
- # Import the correct AutoModel class for the task
7
  from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
8
 
9
  DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
@@ -20,18 +19,17 @@ asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(d
20
  asr_model.eval()
21
  print("✅ ASR Model loaded.")
22
 
23
- # Language Identification (LID) Model
24
  print("\nLoading Language ID model (MMS-LID-1024)...")
25
  lid_model_id = "facebook/mms-lid-1024"
26
  lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
27
- # Load the model with its audio classification head to get logits
28
- lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device) # <-- THIS LINE IS UPDATED
29
  lid_model.eval()
30
  print("✅ Language ID Model loaded.")
31
 
32
 
33
  # --- Language Mappings ---
34
- # Maps the LID model's output code to the ASR model's code
35
  LID_TO_ASR_LANG_MAP = {
36
  "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
37
  "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
@@ -41,7 +39,6 @@ LID_TO_ASR_LANG_MAP = {
41
  "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
42
  }
43
 
44
- # Maps the ASR model's code back to a full name for display
45
  ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
46
 
47
 
@@ -51,7 +48,6 @@ def transcribe_audio_with_lid(audio_path):
51
  return "Please provide an audio file.", "", ""
52
 
53
  try:
54
- # Load and preprocess audio
55
  waveform, sr = torchaudio.load(audio_path)
56
  waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
57
  except Exception as e:
@@ -63,8 +59,9 @@ def transcribe_audio_with_lid(audio_path):
63
  with torch.no_grad():
64
  outputs = lid_model(**inputs)
65
 
66
- # This will now work because the output object has the .logits attribute
67
- predicted_lid_id = outputs.logits.argmax(-1).item()
 
68
  detected_lid_code = lid_model.config.id2label[predicted_lid_id]
69
 
70
  # 2. --- Map to ASR Language Code ---
 
3
  import torchaudio
4
  import gradio as gr
5
  import spaces
 
6
  from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
7
 
8
  DESCRIPTION = "IndicConformer ASR with Automatic Language Identification"
 
19
  asr_model.eval()
20
  print("✅ ASR Model loaded.")
21
 
22
+ # Language Identification (LID) Model - Using your specified model
23
  print("\nLoading Language ID model (MMS-LID-1024)...")
24
  lid_model_id = "facebook/mms-lid-1024"
25
  lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
26
+ # Load the model with its audio classification head
27
+ lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
28
  lid_model.eval()
29
  print("✅ Language ID Model loaded.")
30
 
31
 
32
  # --- Language Mappings ---
 
33
  LID_TO_ASR_LANG_MAP = {
34
  "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
35
  "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
 
39
  "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur"
40
  }
41
 
 
42
  ASR_CODE_TO_NAME = { "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri", "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri", "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri", "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi", "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil", "te": "Telugu", "ur": "Urdu"}
43
 
44
 
 
48
  return "Please provide an audio file.", "", ""
49
 
50
  try:
 
51
  waveform, sr = torchaudio.load(audio_path)
52
  waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
53
  except Exception as e:
 
59
  with torch.no_grad():
60
  outputs = lid_model(**inputs)
61
 
62
+ # CORRECTED: Access logits as the first element of the output tuple
63
+ logits = outputs[0]
64
+ predicted_lid_id = logits.argmax(-1).item()
65
  detected_lid_code = lid_model.config.id2label[predicted_lid_id]
66
 
67
  # 2. --- Map to ASR Language Code ---