AvtnshM committed on
Commit
46398ef
·
verified ·
1 Parent(s): a6ba767
Files changed (1) hide show
  1. app.py +95 -49
app.py CHANGED
@@ -1,7 +1,12 @@
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
- from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, AutoModelForCTC
 
 
 
 
 
5
  import librosa
6
  import numpy as np
7
  from jiwer import wer, cer
@@ -12,19 +17,19 @@ MODEL_CONFIGS = {
12
  "AudioX-North (Jivi AI)": {
13
  "repo": "jiviai/audioX-north-v1",
14
  "model_type": "seq2seq",
15
- "description": "Supports Hindi, Gujarati, Marathi"
16
  },
17
  "IndicConformer (AI4Bharat)": {
18
  "repo": "ai4bharat/indic-conformer-600m-multilingual",
19
  "model_type": "ctc_rnnt",
20
  "description": "Supports 22 Indian languages",
21
- "trust_remote_code": True
22
  },
23
  "MMS (Facebook)": {
24
- "repo": "facebook/mms-1b",
25
  "model_type": "ctc",
26
- "description": "Supports over 1,400 languages (fine-tuning recommended)"
27
- }
28
  }
29
 
30
  # Load model and processor
@@ -33,73 +38,103 @@ def load_model_and_processor(model_name):
33
  repo = config["repo"]
34
  model_type = config["model_type"]
35
  trust_remote_code = config.get("trust_remote_code", False)
36
-
37
  try:
38
- processor = AutoProcessor.from_pretrained(repo, trust_remote_code=trust_remote_code)
39
- if model_type == "seq2seq":
40
- model = AutoModelForSpeechSeq2Seq.from_pretrained(repo, trust_remote_code=trust_remote_code)
41
- else: # ctc or ctc_rnnt
42
- model = AutoModelForCTC.from_pretrained(repo, trust_remote_code=trust_remote_code)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  return model, processor, model_type
44
  except Exception as e:
45
  return None, None, f"Error loading model: {str(e)}"
46
 
 
47
  # Compute metrics (WER, CER, RTF)
48
  def compute_metrics(reference, hypothesis, audio_duration):
49
  if not reference or not hypothesis:
50
  return None, None, None
51
  try:
52
- # Normalize text for better WER/CER calculation (e.g., remove extra spaces, handle numbers)
53
- reference = reference.strip().replace(" ", "").lower()
54
- hypothesis = hypothesis.strip().replace(" ", "").lower()
55
  wer_score = wer(reference, hypothesis)
56
  cer_score = cer(reference, hypothesis)
57
- rtf = (time.time() - start_time) / audio_duration if 'start_time' in globals() else None
 
 
 
 
58
  return wer_score, cer_score, rtf
59
- except Exception as e:
60
- return None, None, f"Error computing metrics: {str(e)}"
61
 
 
 
62
  def transcribe_audio(audio_file, model_name, reference_text=""):
63
  if not audio_file:
64
- return "Please upload an audio file.", None, None, None
65
-
66
  # Load model and processor
67
  model, processor, model_type = load_model_and_processor(model_name)
68
  if isinstance(model_type, str) and model_type.startswith("Error"):
69
- return model_type, None, None, None
70
-
71
  try:
72
  # Load and preprocess audio
73
  audio, sr = librosa.load(audio_file, sr=16000)
74
  audio_duration = len(audio) / sr
75
-
76
- # Process audio
77
  inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
78
- input_features = inputs["input_features"]
79
-
80
- # Measure processing time for RTF
81
  global start_time
82
  start_time = time.time()
 
83
  with torch.no_grad():
84
  if model_type == "seq2seq":
 
85
  outputs = model.generate(input_features)
86
- else: # ctc or ctc_rnnt
87
- outputs = model(input_features).logits
88
- outputs = torch.argmax(outputs, dim=-1)
89
-
90
- # Decode transcription
91
- transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
92
-
93
- # Compute metrics if reference text is provided
94
- wer_score, cer_score, rtf = None, None, None
95
- if reference_text:
96
- wer_score, cer_score, rtf = compute_metrics(reference_text, transcription, audio_duration)
97
- if isinstance(rtf, str):
98
- rtf = None # Handle error case
99
-
100
- return transcription, wer_score, cer_score, rtf
 
 
 
 
 
 
 
 
 
 
101
  except Exception as e:
102
- return f"Error during transcription: {str(e)}", None, None, None
 
103
 
104
  # Gradio interface
105
  def create_interface():
@@ -107,21 +142,32 @@ def create_interface():
107
  return gr.Interface(
108
  fn=transcribe_audio,
109
  inputs=[
110
- gr.Audio(type="filepath", label="Upload Audio File (16kHz recommended)"),
111
- gr.Dropdown(choices=model_choices, label="Select Model", value=model_choices[0]),
112
- gr.Textbox(label="Reference Text (Optional for WER/CER)", placeholder="Enter or paste ground truth text here", lines=3)
 
 
 
 
 
 
 
 
 
 
113
  ],
114
  outputs=[
115
- gr.Textbox(label="Transcription"),
116
  gr.Textbox(label="WER"),
117
  gr.Textbox(label="CER"),
118
- gr.Textbox(label="RTF")
119
  ],
120
  title="Multilingual Speech-to-Text with Metrics",
121
  description="Upload an audio file, select a model, and optionally provide reference text to compute WER, CER, and RTF.",
122
- allow_flagging="never"
123
  )
124
 
 
125
  if __name__ == "__main__":
126
  iface = create_interface()
127
  iface.launch()
 
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
+ from transformers import (
5
+ AutoModelForSpeechSeq2Seq,
6
+ AutoProcessor,
7
+ AutoModelForCTC,
8
+ AutoModel,
9
+ )
10
  import librosa
11
  import numpy as np
12
  from jiwer import wer, cer
 
17
  "AudioX-North (Jivi AI)": {
18
  "repo": "jiviai/audioX-north-v1",
19
  "model_type": "seq2seq",
20
+ "description": "Supports Hindi, Gujarati, Marathi",
21
  },
22
  "IndicConformer (AI4Bharat)": {
23
  "repo": "ai4bharat/indic-conformer-600m-multilingual",
24
  "model_type": "ctc_rnnt",
25
  "description": "Supports 22 Indian languages",
26
+ "trust_remote_code": True,
27
  },
28
  "MMS (Facebook)": {
29
+ "repo": "facebook/mms-1b-all", # fixed repo
30
  "model_type": "ctc",
31
+ "description": "Supports over 1,400 languages (fine-tuning recommended)",
32
+ },
33
  }
34
 
35
  # Load model and processor
 
38
  repo = config["repo"]
39
  model_type = config["model_type"]
40
  trust_remote_code = config.get("trust_remote_code", False)
41
+
42
  try:
43
+ if model_name == "IndicConformer (AI4Bharat)":
44
+ model = AutoModel.from_pretrained(repo, trust_remote_code=True)
45
+ processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
46
+ elif model_name == "MMS (Facebook)":
47
+ model = AutoModelForCTC.from_pretrained(repo)
48
+ processor = AutoProcessor.from_pretrained(repo)
49
+ else: # AudioX-North
50
+ processor = AutoProcessor.from_pretrained(
51
+ repo, trust_remote_code=trust_remote_code
52
+ )
53
+ if model_type == "seq2seq":
54
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
55
+ repo, trust_remote_code=trust_remote_code
56
+ )
57
+ else:
58
+ model = AutoModelForCTC.from_pretrained(
59
+ repo, trust_remote_code=trust_remote_code
60
+ )
61
+
62
  return model, processor, model_type
63
  except Exception as e:
64
  return None, None, f"Error loading model: {str(e)}"
65
 
66
+
67
  # Compute metrics (WER, CER, RTF)
68
def compute_metrics(reference, hypothesis, audio_duration):
    """Compute WER, CER and (when possible) RTF for a transcription.

    Args:
        reference: Ground-truth text.
        hypothesis: Model transcription to score.
        audio_duration: Audio length in seconds (used for RTF).

    Returns:
        Tuple ``(wer, cer, rtf)``. Any element is ``None`` when it cannot
        be computed: both texts must be non-empty for WER/CER, and RTF
        additionally requires the module-level ``start_time`` (set by
        ``transcribe_audio`` just before inference) and a positive duration.
    """
    if not reference or not hypothesis:
        return None, None, None
    try:
        # Case-insensitive comparison; jiwer tokenizes internally.
        reference = reference.strip().lower()
        hypothesis = hypothesis.strip().lower()

        wer_score = wer(reference, hypothesis)
        cer_score = cer(reference, hypothesis)

        # RTF = processing time / audio length. Guard audio_duration > 0
        # explicitly: previously a zero duration raised ZeroDivisionError
        # inside this try and silently discarded the WER/CER above.
        rtf = None
        if "start_time" in globals() and audio_duration > 0:
            rtf = (time.time() - start_time) / audio_duration

        return wer_score, cer_score, rtf
    except Exception:
        # Metrics are best-effort; never let scoring break the app output.
        return None, None, None
84
 
85
+
86
+ # Main transcription function
87
def transcribe_audio(audio_file, model_name, reference_text=""):
    """Transcribe an uploaded audio file with the selected model.

    Returns a 4-tuple of strings ``(transcription_or_error, wer, cer, rtf)``.
    Metric slots are empty strings when no reference text is supplied or a
    metric could not be computed.
    """
    if not audio_file:
        return "Please upload an audio file.", "", "", ""

    # Load model and processor; on failure the loader places an error
    # message in the third slot instead of the model-type string.
    model, processor, model_type = load_model_and_processor(model_name)
    if isinstance(model_type, str) and model_type.startswith("Error"):
        return model_type, "", "", ""

    try:
        # Resample to the 16 kHz rate the processors expect.
        waveform, sample_rate = librosa.load(audio_file, sr=16000)
        duration_s = len(waveform) / sample_rate

        features = processor(waveform, sampling_rate=16000, return_tensors="pt")

        # Record the inference start so compute_metrics can derive RTF.
        global start_time
        start_time = time.time()

        with torch.no_grad():
            if model_type == "seq2seq":
                generated_ids = model.generate(features["input_features"])
            else:  # CTC or RNNT
                frame_logits = model(features["input_values"]).logits
                generated_ids = torch.argmax(frame_logits, dim=-1)
        transcription = processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]

        # Metrics are optional: computed only when a reference is supplied.
        wer_score = cer_score = rtf = ""
        if reference_text and transcription:
            scores = compute_metrics(reference_text, transcription, duration_s)
            wer_score, cer_score, rtf = ("" if s is None else s for s in scores)

        return transcription, str(wer_score), str(cer_score), str(rtf)
    except Exception as e:
        return f"Error during transcription: {str(e)}", "", "", ""
137
+
138
 
139
  # Gradio interface
140
  def create_interface():
 
142
  return gr.Interface(
143
  fn=transcribe_audio,
144
  inputs=[
145
+ gr.Audio(
146
+ type="filepath", label="Upload Audio File (16kHz recommended)"
147
+ ),
148
+ gr.Dropdown(
149
+ choices=model_choices,
150
+ label="Select Model",
151
+ value=model_choices[0],
152
+ ),
153
+ gr.Textbox(
154
+ label="Reference Text (Optional for WER/CER)",
155
+ placeholder="Enter or paste ground truth text here",
156
+ lines=3,
157
+ ),
158
  ],
159
  outputs=[
160
+ gr.Textbox(label="Transcription", show_copy_button=True),
161
  gr.Textbox(label="WER"),
162
  gr.Textbox(label="CER"),
163
+ gr.Textbox(label="RTF"),
164
  ],
165
  title="Multilingual Speech-to-Text with Metrics",
166
  description="Upload an audio file, select a model, and optionally provide reference text to compute WER, CER, and RTF.",
167
+ allow_flagging="never",
168
  )
169
 
170
+
171
if __name__ == "__main__":
    # Build the Gradio app and start the local web server.
    create_interface().launch()