AvtnshM committed on
Commit
a60f163
·
verified ·
1 Parent(s): 0bfe1fb
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -17,7 +17,8 @@ MODEL_CONFIGS = {
17
  "IndicConformer (AI4Bharat)": {
18
  "repo": "ai4bharat/indic-conformer-600m-multilingual",
19
  "model_type": "ctc_rnnt",
20
- "description": "Supports 22 Indian languages"
 
21
  },
22
  "MMS (Facebook)": {
23
  "repo": "facebook/mms-1b",
@@ -31,13 +32,14 @@ def load_model_and_processor(model_name):
31
  config = MODEL_CONFIGS[model_name]
32
  repo = config["repo"]
33
  model_type = config["model_type"]
 
34
 
35
  try:
36
- processor = AutoProcessor.from_pretrained(repo)
37
  if model_type == "seq2seq":
38
- model = AutoModelForSpeechSeq2Seq.from_pretrained(repo)
39
  else: # ctc or ctc_rnnt
40
- model = AutoModelForCTC.from_pretrained(repo)
41
  return model, processor, model_type
42
  except Exception as e:
43
  return None, None, f"Error loading model: {str(e)}"
@@ -47,9 +49,12 @@ def compute_metrics(reference, hypothesis, audio_duration):
47
  if not reference or not hypothesis:
48
  return None, None, None
49
  try:
 
 
 
50
  wer_score = wer(reference, hypothesis)
51
  cer_score = cer(reference, hypothesis)
52
- rtf = audio_duration / time.time() # Simplified; actual RTF needs processing time
53
  return wer_score, cer_score, rtf
54
  except Exception as e:
55
  return None, None, f"Error computing metrics: {str(e)}"
@@ -73,6 +78,7 @@ def transcribe_audio(audio_file, model_name, reference_text=""):
73
  input_features = inputs["input_features"]
74
 
75
  # Measure processing time for RTF
 
76
  start_time = time.time()
77
  with torch.no_grad():
78
  if model_type == "seq2seq":
@@ -87,10 +93,9 @@ def transcribe_audio(audio_file, model_name, reference_text=""):
87
  # Compute metrics if reference text is provided
88
  wer_score, cer_score, rtf = None, None, None
89
  if reference_text:
90
- wer_score, cer_score, rtf_error = compute_metrics(reference_text, transcription, audio_duration)
91
- if isinstance(rtf_error, str):
92
- return transcription, wer_score, cer_score, rtf_error
93
- rtf = (time.time() - start_time) / audio_duration # Actual RTF
94
 
95
  return transcription, wer_score, cer_score, rtf
96
  except Exception as e:
@@ -104,7 +109,7 @@ def create_interface():
104
  inputs=[
105
  gr.Audio(type="filepath", label="Upload Audio File (16kHz recommended)"),
106
  gr.Dropdown(choices=model_choices, label="Select Model", value=model_choices[0]),
107
- gr.Textbox(label="Reference Text (Optional for WER/CER)", placeholder="Enter ground truth text here")
108
  ],
109
  outputs=[
110
  gr.Textbox(label="Transcription"),
 
17
  "IndicConformer (AI4Bharat)": {
18
  "repo": "ai4bharat/indic-conformer-600m-multilingual",
19
  "model_type": "ctc_rnnt",
20
+ "description": "Supports 22 Indian languages",
21
+ "trust_remote_code": True
22
  },
23
  "MMS (Facebook)": {
24
  "repo": "facebook/mms-1b",
 
32
  config = MODEL_CONFIGS[model_name]
33
  repo = config["repo"]
34
  model_type = config["model_type"]
35
+ trust_remote_code = config.get("trust_remote_code", False)
36
 
37
  try:
38
+ processor = AutoProcessor.from_pretrained(repo, trust_remote_code=trust_remote_code)
39
  if model_type == "seq2seq":
40
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(repo, trust_remote_code=trust_remote_code)
41
  else: # ctc or ctc_rnnt
42
+ model = AutoModelForCTC.from_pretrained(repo, trust_remote_code=trust_remote_code)
43
  return model, processor, model_type
44
  except Exception as e:
45
  return None, None, f"Error loading model: {str(e)}"
 
49
  if not reference or not hypothesis:
50
  return None, None, None
51
  try:
52
+ # Normalize text for better WER/CER calculation (e.g., remove extra spaces, handle numbers)
53
+ reference = reference.strip().replace(" ", "").lower()
54
+ hypothesis = hypothesis.strip().replace(" ", "").lower()
55
  wer_score = wer(reference, hypothesis)
56
  cer_score = cer(reference, hypothesis)
57
+ rtf = (time.time() - start_time) / audio_duration if 'start_time' in globals() else None
58
  return wer_score, cer_score, rtf
59
  except Exception as e:
60
  return None, None, f"Error computing metrics: {str(e)}"
 
78
  input_features = inputs["input_features"]
79
 
80
  # Measure processing time for RTF
81
+ global start_time
82
  start_time = time.time()
83
  with torch.no_grad():
84
  if model_type == "seq2seq":
 
93
  # Compute metrics if reference text is provided
94
  wer_score, cer_score, rtf = None, None, None
95
  if reference_text:
96
+ wer_score, cer_score, rtf = compute_metrics(reference_text, transcription, audio_duration)
97
+ if isinstance(rtf, str):
98
+ rtf = None # Handle error case
 
99
 
100
  return transcription, wer_score, cer_score, rtf
101
  except Exception as e:
 
109
  inputs=[
110
  gr.Audio(type="filepath", label="Upload Audio File (16kHz recommended)"),
111
  gr.Dropdown(choices=model_choices, label="Select Model", value=model_choices[0]),
112
+ gr.Textbox(label="Reference Text (Optional for WER/CER)", placeholder="Enter or paste ground truth text here", lines=3)
113
  ],
114
  outputs=[
115
  gr.Textbox(label="Transcription"),