Devion333 committed on
Commit
fa5a1d8
·
verified ·
1 Parent(s): 269d84d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -33
app.py CHANGED
@@ -1,44 +1,166 @@
 
1
  import gradio as gr
2
- from transformers import pipeline
 
 
3
 
4
- # Load ASR pipeline
5
- asr_pipeline = pipeline(task="automatic-speech-recognition", model="import gradio as gr
6
- from transformers import pipeline
 
 
 
 
 
7
 
8
- # Load ASR pipeline
9
- asr_pipeline = pipeline(task="automatic-speech-recognition", model="Devion333/wav2vec2-xls-r-300m-dv")
10
- # 🔹 Replace with your own model if you trained one, e.g., "Devion333/whisper-small-dv-syn"
11
 
12
- def transcribe(audio):
13
- result = asr_pipeline(audio)
14
- return result["text"]
15
 
16
- # Build Gradio app
17
- gradio_app = gr.Interface(
18
- fn=transcribe,
19
- inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Speak or Upload Audio"),
20
- outputs=gr.Textbox(label="Transcription"),
21
- title="Speech-to-Text (ASR)",
22
- description="Upload an audio file or record speech and get the transcription using a Hugging Face ASR model."
23
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- if __name__ == "__main__":
26
- gradio_app.launch()
27
- ")
28
- # 🔹 Replace with your own model if you trained one, e.g., "Devion333/whisper-small-dv-syn"
 
 
 
 
 
29
 
30
- def transcribe(audio):
31
- result = asr_pipeline(audio)
32
- return result["text"]
33
 
34
- # Build Gradio app
35
- gradio_app = gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
36
  fn=transcribe,
37
- inputs=gr.Audio(sources=["upload", "microphone"], type="filepath", label="Speak or Upload Audio"),
38
- outputs=gr.Textbox(label="Transcription"),
39
- title="Speech-to-Text (ASR)",
40
- description="Upload an audio file or record speech and get the transcription using a Hugging Face ASR model."
 
 
41
  )
42
 
43
- if __name__ == "__main__":
44
- gradio_app.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
  import gradio as gr
3
+ import subprocess
4
+ import sys
5
+ import os
6
 
7
@spaces.GPU
def transcribe(audio_file):
    """Transcribe an audio file to lowercase text using the CTC model and LM decoder.

    Relies on module-level names created at start-up: ``torch``,
    ``torchaudio``, ``np``, ``device``, ``torch_dtype``, ``processor``,
    ``model``, ``MIN_LENGTH`` and ``MAX_LENGTH``.

    Args:
        audio_file: Path of the audio file to transcribe. ``None`` or an
            empty value produces a message instead of an attempted load.

    Returns:
        The lowercase transcription, or a human-readable message when the
        input is missing, the duration is outside
        ``[MIN_LENGTH, MAX_LENGTH]`` seconds, or any processing step raises.
    """
    # Nothing was uploaded or recorded: answer directly instead of letting
    # the loader fail with an opaque error.
    if not audio_file:
        return "No audio provided. Please upload or record audio."

    try:
        # Load audio file
        waveform, sample_rate = torchaudio.load(audio_file)

        # Move waveform to the configured device so resampling runs there
        waveform = waveform.to(device)

        # Duration in seconds: second dimension of the waveform over the rate
        duration = waveform.shape[1] / sample_rate

        # Reject clips outside the supported length window
        if duration < MIN_LENGTH or duration > MAX_LENGTH:
            return f"Audio duration is too short or too long. Duration: {duration} seconds"

        # The processor below is called with sampling_rate=16_000, so bring
        # every other rate to 16 kHz first
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(sample_rate, 16000).to(device)
            waveform = resampler(waveform)

        # Average the first dimension down to a single channel
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Move to CPU for numpy conversion
        waveform = waveform.cpu()
        audio_input = waveform.squeeze().numpy()

        # Ensure audio input is float32
        if audio_input.dtype != np.float32:
            audio_input = audio_input.astype(np.float32)

        # Turn the raw samples into model input on the target device
        input_values = processor(
            audio_input,
            sampling_rate=16_000,
            return_tensors="pt",
        ).input_values.to(device)

        # Match the model's precision when it was loaded as float16
        if torch_dtype == torch.float16:
            input_values = input_values.half()

        # Inference only: no gradients needed
        with torch.no_grad():
            logits = model(input_values).logits

        # Decode the first (only) item of the batch with the LM-backed decoder
        transcription = processor.decode(logits[0].cpu().numpy())

        # Log the full decoder output, return just the lowercase text
        print(transcription)
        return transcription.text.lower()

    except Exception as e:
        # UI boundary: surface the failure as text rather than crashing
        return f"Error during transcription: {str(e)}"
64
 
65
+ # Create Gradio interface
 
 
66
 
67
# Stylesheet passed to the Blocks container: enlarges the text and sets the
# font family / line height of textareas inside elements with class
# "textbox1", and hides textareas inside elements with class "textbox2".
css = """
.textbox1 textarea {
    font-size: 18px !important;
    font-family: 'MV_Faseyha', 'Faruma', 'A_Faruma' !important;
    line-height: 1.8 !important;
}
.textbox2 textarea {
    display: none;
}
"""

# Top-level container that is launched at start-up.
demo = gr.Blocks(css=css)

# Single-function interface: one audio input (upload or microphone, handed
# to `transcribe` as a file path) and one right-to-left text output.
tab_audio = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(
            sources=["upload", "microphone"],
            type="filepath",
            label="Audio",
        ),
    ],
    outputs=gr.Textbox(
        label="Transcription",
        rtl=True,
        elem_classes="textbox1",
    ),
    title="Transcribe Dhivehi Audio",
    allow_flagging="never",
)

# Mount the interface as the only tab of the container.
with demo:
    gr.TabbedInterface([tab_audio], ["Audio"])
92
+
93
+
94
def install_requirements(requirements_path='requirements.txt'):
    """Install the packages listed in a pip requirements file.

    Runs ``pip install -r <requirements_path> --no-cache-dir`` through the
    interpreter that is executing this script (``sys.executable -m pip``).

    Args:
        requirements_path: Path of the requirements file. Defaults to
            ``'requirements.txt'``, resolved against the current working
            directory.

    Returns:
        True when pip exits successfully; False when the file does not
        exist, pip exits with a non-zero status, or any other exception is
        raised while running it. Progress and errors are printed.
    """
    # Check that the requirements file exists before invoking pip
    if not os.path.exists(requirements_path):
        print(f"Error: {requirements_path} not found")
        return False

    try:
        print("Installing requirements...")
        # Using --no-cache-dir to avoid memory issues
        subprocess.check_call([
            sys.executable,
            "-m",
            "pip",
            "install",
            "-r",
            requirements_path,
            "--no-cache-dir",
        ])
        print("Successfully installed all requirements")
        return True
    except subprocess.CalledProcessError as e:
        # pip ran but reported failure through its exit status
        print(f"Error installing requirements: {e}")
        return False
    except Exception as e:
        # Anything else, e.g. the interpreter could not be started
        print(f"Unexpected error: {e}")
        return False
122
+
123
+ # Launch the interface
124
# Script entry point. The guard must use the dunder names: a bare
# `name == "main"` raises NameError before anything can start.
if __name__ == "__main__":
    success = install_requirements()
    if success:
        print("All requirements installed successfully")

        # Imported only after the install step above has succeeded, so the
        # packages can be installed at start-up before they are needed.
        from transformers import Wav2Vec2ForCTC, Wav2Vec2ProcessorWithLM
        import torch
        import torchaudio
        import numpy as np

        # Device and dtype configuration: float16 on CUDA, float32 on CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        MODEL_NAME = "Devion333/wav2vec2-xls-r-300m-dv"  # Trained on common voice with ngram from news corpora
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/wav2vec2-large-mms-1b-cv" # Trained on Common Voice Data (Unknown Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/whisper-small-dv-syn-md" # Trained on 100% Synthetic Data (150 Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/whisper-small-cv" # Trained on Common Voice Data (Unknown Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/whisper-medium-dv-syn-md" # Trained on 100% Synthetic Data (150 Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/whisper-medium-cv" # Trained on Common Voice Data (Unknown Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/whisper-large-v3-dv-syn-md" # Trained on 100% Synthetic Data (150 Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/whisper-large-v3-cv" # Trained on Common Voice Data (Unknown Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/whisper-large-v3-calls-md" # Trained on phone calls (65 Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/wav2vec2-large-mms-1b-calls-md" # Trained on phone calls (65 Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/wav2vec2-large-xlsr-calls-md" # Trained on phone calls (23 Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/wav2vec2-large-xlsr-dv-syn-md" # Trained on 100% Synthetic Data (80 Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/dhivehi-asr-full-ctc" # Trained on multiple datasets (350+ Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/dhivehi-asr-full-ctc-v2" # Trained on multiple datasets (350+ Hours)
        # MODEL_NAME = "/home/rusputin/lab/audio/fine-tunes/dhivehi-asr-full-whisper-v3" # Trained on multiple datasets (350+ Hours)

        # Load model and processor with LM; the model is placed on `device`
        # in the dtype chosen above
        processor = Wav2Vec2ProcessorWithLM.from_pretrained(MODEL_NAME)
        model = Wav2Vec2ForCTC.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch_dtype,
        ).to(device)

        # Accepted clip duration window, in seconds
        MAX_LENGTH = 120  # 2 minutes
        MIN_LENGTH = 1  # 1 second

        demo.launch()
    else:
        print("Failed to install some requirements")