DineshJ96 committed on
Commit
dc7a247
·
1 Parent(s): 9626ac7

files added

Browse files
Files changed (2) hide show
  1. app.py +300 -0
  2. requirements.txt +21 -0
app.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ import gc
5
+ import json
6
+ import whisperx
7
+ from pyannote.audio import Pipeline
8
+ from huggingface_hub import HfFolder
9
+ from transformers import pipeline
10
+ import numpy as np
11
+ import soundfile as sf
12
+ import io
13
+ import tempfile
14
+
15
+ # --- Configuration ---
16
+ HF_TOKEN = os.getenv("HF_TOKEN")
17
+
18
+ if not HF_TOKEN:
19
+ print("WARNING: HF_TOKEN environment variable not set. Please set it as a Space secret or directly for local testing.")
20
+ print("Visit https://huggingface.co/settings/tokens to create one and accept model conditions for pyannote/speaker-diarization, etc.")
21
+
22
+ device = "cuda" if torch.cuda.is_available() else "cpu"
23
+ compute_type = "float16" if device == "cuda" else "int8"
24
+ whisper_model_size = "medium" # 'large-v2' is best but most resource intensive.
25
+ # 'small' or 'medium' are better for free tiers.
26
+
27
+ # --- Global Models (loaded once) ---
28
+ whisper_model_global = None
29
+ diarize_pipeline_global = None
30
+ translation_pipeline_global = None
31
+
32
+ def load_all_models():
33
+ global whisper_model_global, diarize_pipeline_global, translation_pipeline_global
34
+
35
+ print(f"Loading WhisperX model ({whisper_model_size})...")
36
+ whisper_model_global = whisperx.load_model(whisper_model_size, device=device, compute_type=compute_type)
37
+
38
+ print("Loading Pyannote Diarization Pipeline...")
39
+ if not HF_TOKEN:
40
+ raise ValueError("Hugging Face token (HF_TOKEN) not set. Please set it as a Space secret.")
41
+ diarize_pipeline_global = whisperx.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)
42
+
43
+ print("Loading translation model (Helsinki-NLP/opus-mt-ta-en)...")
44
+ try:
45
+ translation_pipeline_global = pipeline(
46
+ "translation",
47
+ model="Helsinki-NLP/opus-mt-ta-en",
48
+ device=0 if device == "cuda" else -1
49
+ )
50
+ except Exception as e:
51
+ print(f"Could not load translation model: {e}")
52
+ translation_pipeline_global = None
53
+
54
+ # Load models when the Gradio app starts
55
+ load_all_models()
56
+
57
def _resample_to_16k(waveform, sample_rate):
    """Resample a 1-D float waveform from *sample_rate* to 16 kHz; returns numpy."""
    # Imported lazily so torchaudio is only required when resampling is needed.
    from torchaudio.transforms import Resample
    resampler = Resample(orig_freq=sample_rate, new_freq=16000)
    return resampler(torch.from_numpy(waveform).float()).numpy()


def convert_audio_for_whisper(audio_input):
    """
    Convert a Gradio audio input (filepath or (sr, numpy_array)) to a 16 kHz
    mono PCM-16 WAV file, which is what WhisperX expects.

    Args:
        audio_input: a filepath string (gr.Audio(type="filepath")) or a
            (sample_rate, numpy_array) tuple (microphone / numpy mode).

    Returns:
        Path to a temporary WAV file (the caller is responsible for deleting
        it), or None on failure or unrecognized input.
    """
    if isinstance(audio_input, str):  # filepath from gr.Audio(type="filepath")
        # Close the handle right away so sf.write can reopen the path
        # (required on Windows; harmless elsewhere).
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        temp_wav_path = tmp.name
        try:
            waveform, sample_rate = sf.read(audio_input)
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)  # stereo -> mono

            if sample_rate != 16000:
                print(f"Warning: Audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
                waveform = _resample_to_16k(waveform, sample_rate)

            sf.write(temp_wav_path, waveform, 16000, format='WAV', subtype='PCM_16')
            return temp_wav_path
        except Exception as e:
            print(f"Error converting uploaded audio: {e}")
            # BUGFIX: don't leak the temp file when conversion fails.
            if os.path.exists(temp_wav_path):
                os.unlink(temp_wav_path)
            return None

    if isinstance(audio_input, tuple):  # (sr, numpy_array) from the microphone
        sample_rate, samples = audio_input

        # BUGFIX: scale integer PCM (Gradio mic audio is typically int16) by the
        # dtype's full range instead of peak-normalizing. Peak normalization
        # altered the recording's gain and raised ZeroDivision on pure silence.
        if np.issubdtype(samples.dtype, np.integer):
            samples = samples.astype(np.float32) / np.iinfo(samples.dtype).max
        elif samples.dtype != np.float32:
            samples = samples.astype(np.float32)

        if samples.ndim > 1:
            samples = samples.mean(axis=1)  # stereo -> mono

        if sample_rate != 16000:
            print(f"Warning: Microphone audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
            samples = _resample_to_16k(samples, sample_rate)

        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        tmp.close()
        temp_wav_path = tmp.name
        try:
            sf.write(temp_wav_path, samples, 16000, format='WAV', subtype='PCM_16')
            return temp_wav_path
        except Exception as e:
            print(f"Error writing microphone audio to temp file: {e}")
            if os.path.exists(temp_wav_path):
                os.unlink(temp_wav_path)
            return None

    return None  # unrecognized input type
120
+
121
def process_audio_for_web(audio_input):
    """
    Run the full pipeline on a single Gradio audio input.

    Steps: convert to 16 kHz WAV -> transcribe with WhisperX -> word-align ->
    speaker diarization (pyannote) -> translate Tamil segments to English.

    Args:
        audio_input: Gradio audio value - a filepath string or (sr, numpy) tuple.

    Returns:
        A 4-tuple matching the Gradio outputs:
        (diarized transcript, translated transcript, language message,
         path to a downloadable .txt report or None).
        On error the first element carries the error message instead.
    """
    if audio_input is None:
        return "Please upload an audio file or record from microphone.", "", "", None

    audio_file_path = convert_audio_for_whisper(audio_input)
    if not audio_file_path:
        return "Error: Could not process audio input. Please ensure it's a valid audio format.", "", "", None

    print(f"Processing audio from temp file: {audio_file_path}")

    try:
        audio = whisperx.load_audio(audio_file_path)

        # 1. Transcribe (batch_size=1 keeps memory low on free-tier hardware).
        print("Transcribing audio...")
        transcription_result = whisper_model_global.transcribe(audio, batch_size=1)
        detected_language = transcription_result["language"]
        print(f"Detected overall language: {detected_language}")

        # 2. Align - word-level timestamps are required for speaker assignment.
        print("Aligning transcription with audio...")
        align_model_local, metadata = whisperx.load_align_model(language_code=detected_language, device=device)
        transcription_result = whisperx.align(transcription_result["segments"], align_model_local, audio, device, return_char_alignments=False)
        # The alignment model is large and only needed once - free it eagerly.
        del align_model_local
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # 3. Diarize and attach speaker labels to the aligned segments.
        print("Performing speaker diarization...")
        diarize_segments = diarize_pipeline_global(audio_file_path)
        final_result = whisperx.assign_word_speakers(diarize_segments, transcription_result)

        # Group segments per speaker and build the chronological display text.
        speaker_transcripts_raw = {}
        diarized_display_lines = []
        for segment in final_result["segments"]:
            speaker_id = segment.get("speaker", "UNKNOWN_SPEAKER")
            text = segment["text"].strip()
            start = segment["start"]
            end = segment["end"]
            speaker_transcripts_raw.setdefault(speaker_id, []).append({
                "start": start,
                "end": end,
                "text": text
            })
            diarized_display_lines.append(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker_id}: {text}")

        full_diarized_text_str = "\n".join(diarized_display_lines)

        # 4. Translate Tamil segments to English, grouped per speaker.
        translated_speaker_data = {}  # speaker -> list of original/translated segments
        translated_display_lines = []
        if translation_pipeline_global:
            for speaker, segments in speaker_transcripts_raw.items():
                translated_speaker_data[speaker] = []
                translated_display_lines.append(f"\n--- Speaker {speaker} (Original & Translated) ---")
                for seg in segments:
                    original_text = seg['text']
                    translated_text_output = original_text  # fall back to original on failure

                    # BUGFIX: inclusive Tamil Unicode block U+0B80..U+0BFF; the
                    # previous strict '<'/'>' comparison missed both boundaries.
                    is_tamil_char_present = any('\u0b80' <= char <= '\u0bff' for char in original_text)

                    if original_text and (detected_language == 'ta' or is_tamil_char_present):
                        try:
                            translated_result = translation_pipeline_global(original_text, src_lang="ta", tgt_lang="en")
                            translated_text_output = translated_result[0]['translation_text']
                        except Exception as e:
                            print(f"Error translating segment for speaker {speaker}: '{original_text}'. Error: {e}. Keeping original text.")

                    translated_speaker_data[speaker].append({
                        "start": seg['start'],
                        "end": seg['end'],
                        "original_text": original_text,
                        "translated_text": translated_text_output
                    })
                    translated_display_lines.append(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {original_text}")
                    translated_display_lines.append(f"    Translated: {translated_text_output}")

            translated_output_str = "\n".join(translated_display_lines)
        else:
            translated_output_str = "Translation model not loaded. Skipping translation."

        # 5. Write a downloadable .txt report. Close the NamedTemporaryFile
        # handle immediately so the path can be reopened for writing
        # (required on Windows; harmless elsewhere).
        tmp_report = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
        tmp_report.close()
        output_filename = tmp_report.name
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write("--- Speaker-wise Original Transcription ---\n\n")
            for speaker, segments in speaker_transcripts_raw.items():
                f.write(f"\n### Speaker {speaker} ###\n")
                for seg in segments:
                    f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}\n")

            f.write("\n\n--- Speaker-wise Translated Transcription (to English) ---\n\n")
            # translated_speaker_data is always defined now, so no fragile
            # 'in locals()' check is needed.
            if translation_pipeline_global:
                for speaker, segments in translated_speaker_data.items():
                    f.write(f"\n### Speaker {speaker} ###\n")
                    for seg in segments:
                        f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {seg['original_text']}\n")
                        f.write(f"    Translated: {seg['translated_text']}\n")
            else:
                f.write("Translation output not available or translation model not loaded.\n")

            f.write(f"\n\nOverall Detected Language: {detected_language}")

        return full_diarized_text_str, translated_output_str, f"Detected overall language: {detected_language}", output_filename

    except Exception as e:
        import traceback
        error_message = f"An error occurred: {e}\n{traceback.format_exc()}"
        print(error_message)
        return error_message, "", "", None
    finally:
        # BUGFIX: clean up the temporary audio file on every exit path (the
        # original duplicated the unlink in both success and error branches).
        if os.path.exists(audio_file_path):
            os.unlink(audio_file_path)
250
+
251
# --- Gradio Interface ---
with gr.Blocks(title="Language-Agnostic Speaker Diarization, Transcription, and Translation") as demo:
    gr.Markdown(
        """
        # Language-Agnostic Speaker Diarization, Transcription, and Translation
        Upload an audio file (WAV, MP3, etc.) or record directly from your microphone.
        The system will identify speakers, transcribe their speech (in detected language),
        and provide an English translation for relevant segments.
        """
    )

    with gr.Row():
        audio_input = gr.Audio(
            type="filepath",
            sources=["upload", "microphone"],
            label="Upload Audio File or Record from Microphone"
        )

    with gr.Row():
        process_button = gr.Button("Process Audio", variant="primary")

    with gr.Column():
        detected_language_output = gr.Textbox(label="Detected Overall Language")
        # Chronological transcript with per-segment speaker labels.
        diarized_transcription_output = gr.Textbox(label="Diarized Transcription (Chronological with Speaker Labels)", lines=10, interactive=False)
        # Translation output grouped per speaker.
        translated_transcription_output = gr.Textbox(label="Translated Transcription (to English, per Speaker)", lines=10, interactive=False)

    # BUGFIX: visible=True - the component was created with visible=False and no
    # handler ever toggled visibility, so the download link could never appear.
    download_button = gr.File(label="Download Transcription (.txt)", interactive=False, visible=True)

    process_button.click(
        fn=process_audio_for_web,
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button]
    )

    gr.Examples(
        [
            # Add paths to example audio files here. The files must exist in
            # the Hugging Face Space repository, e.g. "sample_two_speakers.wav".
        ],
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button],
        fn=process_audio_for_web,
        cache_examples=False
    )

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core AI/ML libraries
2
+ torch==2.5.1 # Updated to meet whisperx's minimum requirement
3
+ torchaudio==2.5.1 # Updated to match torch version
4
+ transformers
5
+ accelerate
6
+ sentencepiece
7
+
8
+ # Audio processing & Diarization
9
+ pyannote.audio
10
+ soundfile
11
+ ffmpeg-python
12
+
13
+ # Gradio for the UI
14
+ gradio
15
+
16
+ # WhisperX specific
17
+ whisperx @ git+https://github.com/m-bain/whisperX.git
18
+
19
+ # General utilities (often useful)
20
+ numpy
21
+ scipy