File size: 14,924 Bytes
dc7a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0baba61
dc7a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d926715
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275f6d2
d926715
 
 
dc7a247
 
 
 
 
 
 
 
 
 
d926715
dc7a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d926715
 
dc7a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d926715
dc7a247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import gradio as gr
import os
import torch
import gc
import json
import whisperx
from pyannote.audio import Pipeline
from huggingface_hub import HfFolder
from transformers import pipeline
import numpy as np
import soundfile as sf
import io
import tempfile

# --- Configuration ---
# Hugging Face token is required by the gated pyannote diarization models.
HF_TOKEN = os.getenv("HF_TOKEN")

if not HF_TOKEN:
    print("WARNING: HF_TOKEN environment variable not set. Please set it as a Space secret or directly for local testing.")
    print("Visit https://huggingface.co/settings/tokens to create one and accept model conditions for pyannote/speaker-diarization, etc.")

# Pick the fastest device/precision combination available on this machine:
# fp16 on GPU, int8 quantization on CPU.
if torch.cuda.is_available():
    device = "cuda"
    compute_type = "float16"
else:
    device = "cpu"
    compute_type = "int8"

# 'large-v2' is best but most resource intensive.
# 'small' or 'medium' are better for free tiers.
whisper_model_size = "medium"

# --- Global Models (loaded once) ---
# Populated by load_all_models() at startup; None until then.
whisper_model_global = None
diarize_pipeline_global = None
translation_pipeline_global = None

def load_all_models():
    """Load the ASR, diarization, and translation models into module globals.

    Populates ``whisper_model_global``, ``diarize_pipeline_global`` and
    ``translation_pipeline_global``. The translation model is optional: on
    failure it is left as ``None`` and the app degrades gracefully.

    Raises:
        ValueError: if HF_TOKEN is not set (the pyannote diarization
            pipeline cannot be downloaded without it).
    """
    global whisper_model_global, diarize_pipeline_global, translation_pipeline_global

    # Fail fast: the diarization pipeline below requires the token, so do not
    # spend time/VRAM loading the Whisper model only to raise afterwards.
    if not HF_TOKEN:
        raise ValueError("Hugging Face token (HF_TOKEN) not set. Please set it as a Space secret.")

    print(f"Loading WhisperX model ({whisper_model_size})...")
    whisper_model_global = whisperx.load_model(whisper_model_size, device=device, compute_type=compute_type)

    print("Loading Pyannote Diarization Pipeline...")
    diarize_pipeline_global = whisperx.diarize.DiarizationPipeline(use_auth_token=HF_TOKEN, device=device)

    print("Loading translation model (Helsinki-NLP/opus-mt-ta-en)...")
    try:
        translation_pipeline_global = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-ta-en",
            device=0 if device == "cuda" else -1
        )
    except Exception as e:
        # Translation is a best-effort feature; keep the app usable without it.
        print(f"Could not load translation model: {e}")
        translation_pipeline_global = None

# Load models once at import time, before the Gradio app starts serving, so
# the first user request does not pay the model download/warm-up cost.
load_all_models()

def _downmix_to_mono(samples):
    """Average all channels into one mono channel; no-op if already mono."""
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    return samples


def _resample_to_16k(samples, sample_rate):
    """Resample a mono waveform to 16 kHz; returns input unchanged at 16 kHz."""
    if sample_rate == 16000:
        return samples
    # torchaudio provides proper band-limited resampling; imported lazily so
    # the path where no resampling is needed does not require it.
    from torchaudio.transforms import Resample
    resampler = Resample(orig_freq=sample_rate, new_freq=16000)
    return resampler(torch.from_numpy(samples).float()).numpy()


def _remove_quietly(path):
    """Best-effort delete of a temporary file; ignores filesystem errors."""
    try:
        os.unlink(path)
    except OSError:
        pass


def convert_audio_for_whisper(audio_input):
    """
    Converts Gradio audio input (filepath or (sr, numpy_array)) to a 16kHz mono WAV file
    that WhisperX expects.

    Parameters:
        audio_input: either a filepath string (gr.Audio(type="filepath")) or an
            (sample_rate, numpy_array) tuple from a numpy/microphone source.

    Returns:
        Path to a temporary 16 kHz mono PCM-16 WAV file, or None on failure.
        The caller owns (and must eventually delete) the returned file.
    """
    if isinstance(audio_input, str):  # Filepath from gr.Audio(type="filepath")
        temp_wav_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
        try:
            waveform, sample_rate = sf.read(audio_input)
            waveform = _downmix_to_mono(waveform)
            if sample_rate != 16000:
                print(f"Warning: Audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
            waveform = _resample_to_16k(waveform, sample_rate)
            sf.write(temp_wav_path, waveform, 16000, format='WAV', subtype='PCM_16')
            return temp_wav_path
        except Exception as e:
            print(f"Error converting uploaded audio: {e}")
            _remove_quietly(temp_wav_path)  # don't leak the temp file on failure
            return None

    if isinstance(audio_input, tuple):  # (sr, numpy_array) from gr.Audio(type="numpy") or microphone
        sample_rate, numpy_array = audio_input
        numpy_array = _downmix_to_mono(numpy_array)

        # Integer PCM (e.g. int16 microphone capture) -> peak-normalized float32.
        if numpy_array.dtype != np.float32:
            numpy_array = numpy_array.astype(np.float32)
            peak = np.max(np.abs(numpy_array)) if numpy_array.size else 0.0
            # Guard: silent (all-zero) or empty input previously divided by
            # zero here, producing a NaN waveform.
            if peak > 0:
                numpy_array = numpy_array / peak

        if sample_rate != 16000:
            print(f"Warning: Microphone audio sample rate is {sample_rate}Hz. Resampling to 16kHz.")
        numpy_array = _resample_to_16k(numpy_array, sample_rate)

        temp_wav_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
        try:
            sf.write(temp_wav_path, numpy_array, 16000, format='WAV', subtype='PCM_16')  # Always write at 16kHz
            return temp_wav_path
        except Exception as e:
            print(f"Error writing microphone audio to temp file: {e}")
            _remove_quietly(temp_wav_path)  # don't leak the temp file on failure
            return None

    return None

def process_audio_for_web(audio_input):
    """
    Run the full pipeline on one Gradio audio input:
    convert -> transcribe (WhisperX) -> align -> diarize (pyannote) ->
    translate Tamil segments to English -> write a downloadable .txt report.

    Parameters:
        audio_input: filepath string or (sample_rate, numpy_array) tuple from
            gr.Audio; None when the user submitted nothing.

    Returns:
        A 4-tuple matching the Gradio outputs:
        (chronological diarized transcription, per-speaker translated
        transcription, detected-language message, path to the .txt report
        or None). On error the first element carries the error message and
        the file path is None.
    """

    if audio_input is None:
        return "Please upload an audio file or record from microphone.", "", "", None

    audio_file_path = convert_audio_for_whisper(audio_input)
    if not audio_file_path:
        return "Error: Could not process audio input. Please ensure it's a valid audio format.", "", "", None
    
    print(f"Processing audio from temp file: {audio_file_path}")

    try:
        audio = whisperx.load_audio(audio_file_path)

        # 1. Transcribe
        print("Transcribing audio...")
        # batch_size=1 keeps peak memory low on free-tier hardware.
        transcription_result = whisper_model_global.transcribe(audio, batch_size=1) 
        detected_language = transcription_result["language"]
        print(f"Detected overall language: {detected_language}")

        # 2. Align (word-level timestamps, needed for speaker assignment below)
        print("Aligning transcription with audio...")
        align_model_local = None # Initialize to None to prevent UnboundLocalError in outer except
        try:
            # Load the alignment model based on the detected language
            # The 'device' parameter is passed here, not to whisperx.align
            align_model_local, metadata = whisperx.load_align_model(language_code=detected_language, device=device)
        except Exception as e:
            # Handle cases where the alignment model for the detected language cannot be loaded
            print(f"Error loading alignment model for language '{detected_language}': {e}")
            import traceback
            print(traceback.format_exc())
            # Provide a user-friendly message, possibly suggesting supported languages
            return f"Error: Could not load alignment model for language '{detected_language}'. Alignment is typically supported for English, French, German, Spanish, Italian, Japanese, Chinese, Dutch, and Portuguese. Details: {e}", "", "", None

        # Perform alignment using the loaded model
        # Removed 'device' from here as the model itself is already on the correct device
        transcription_result = whisperx.align(transcription_result["segments"], align_model_local, audio, device, return_char_alignments=False)
        
        # Removed the duplicate whisperx.align call and 'del align_model_local'
        # as it can cause issues if an error occurs later.
        gc.collect()
        if device == "cuda":
            torch.cuda.empty_cache()

        # 3. Diarize: who spoke when, then attach speaker labels to the
        # aligned words/segments.
        print("Performing speaker diarization...")
        diarize_segments = diarize_pipeline_global(audio_file_path)
        final_result = whisperx.assign_word_speakers(diarize_segments, transcription_result)

        # Group segments per speaker (for the report) while also building the
        # chronological display text.
        speaker_transcripts_raw = {}
        # Prepare for display in diarized_transcription_output
        diarized_display_lines = [] 

        for segment in final_result["segments"]:
            speaker_id = segment.get("speaker", "UNKNOWN_SPEAKER")
            text = segment["text"].strip()
            start = segment["start"]
            end = segment["end"]

            if speaker_id not in speaker_transcripts_raw:
                speaker_transcripts_raw[speaker_id] = []
            speaker_transcripts_raw[speaker_id].append({
                "start": start,
                "end": end,
                "text": text
            })
            diarized_display_lines.append(f"[{start:.2f}s - {end:.2f}s] Speaker {speaker_id}: {text}")

        full_diarized_text_str = "\n".join(diarized_display_lines)

        # 4. Translate Tamil segments to English (skipped if the optional
        # translation model failed to load at startup).
        translated_display_lines = []
        if translation_pipeline_global:
            translated_speaker_data = {} # To hold translated segments per speaker
            for speaker, segments in speaker_transcripts_raw.items():
                translated_speaker_data[speaker] = [] # Initialize for current speaker
                
                translated_display_lines.append(f"\n--- Speaker {speaker} (Original & Translated) ---")
                for seg in segments:
                    original_text = seg['text']
                    translated_text_output = original_text

                    # Heuristic: any code point inside the Tamil Unicode block
                    # (U+0B80..U+0BFF) marks the segment as Tamil even when the
                    # overall detected language differs (code-switched audio).
                    is_tamil_char_present = any(ord(char) > 0x0B80 and ord(char) < 0x0BFF for char in original_text)
                    
                    if original_text and (detected_language == 'ta' or is_tamil_char_present): 
                        try:
                            # NOTE(review): opus-mt-ta-en is a fixed-pair model;
                            # src_lang/tgt_lang look redundant here — confirm they
                            # are accepted/ignored by this pipeline version.
                            translated_result = translation_pipeline_global(original_text, src_lang="ta", tgt_lang="en")
                            translated_text_output = translated_result[0]['translation_text']
                        except Exception as e:
                            print(f"Error translating segment for speaker {speaker}: '{original_text}'. Error: {e}. Keeping original text.")
                    
                    translated_speaker_data[speaker].append({
                        "start": seg['start'],
                        "end": seg['end'],
                        "original_text": original_text,
                        "translated_text": translated_text_output
                    })
                    translated_display_lines.append(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {original_text}")
                    translated_display_lines.append(f"           Translated: {translated_text_output}")
                
            translated_output_str = "\n".join(translated_display_lines)
        else:
            translated_output_str = "Translation model not loaded. Skipping translation."

        # Create a temporary file for download; it must outlive this function
        # so Gradio can serve it, hence delete=False and no cleanup here.
        output_filename = tempfile.NamedTemporaryFile(suffix=".txt", delete=False).name
        with open(output_filename, "w", encoding="utf-8") as f:
            f.write("--- Speaker-wise Original Transcription ---\n\n")
            # Write original transcription per speaker
            for speaker, segments in speaker_transcripts_raw.items():
                f.write(f"\n### Speaker {speaker} ###\n")
                for seg in segments:
                    f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] {seg['text']}\n")
            
            f.write("\n\n--- Speaker-wise Translated Transcription (to English) ---\n\n")
            # Write translated transcription per speaker
            if translation_pipeline_global and 'translated_speaker_data' in locals():
                for speaker, segments in translated_speaker_data.items():
                    f.write(f"\n### Speaker {speaker} ###\n")
                    for seg in segments:
                        f.write(f"[{seg['start']:.2f}s - {seg['end']:.2f}s] Original: {seg['original_text']}\n")
                        f.write(f"           Translated: {seg['translated_text']}\n")
            else:
                f.write("Translation output not available or translation model not loaded.\n")
            
            f.write(f"\n\nOverall Detected Language: {detected_language}")

        # Clean up the temporary audio file
        os.unlink(audio_file_path)

        return full_diarized_text_str, translated_output_str, f"Detected overall language: {detected_language}", output_filename

    except Exception as e:
        import traceback
        error_message = f"An error occurred: {e}\n{traceback.format_exc()}"
        print(error_message)
        # Clean up temp audio file even on error
        if audio_file_path and os.path.exists(audio_file_path):
            os.unlink(audio_file_path)
        return error_message, "", "", None

# --- Gradio Interface ---
with gr.Blocks(title="Language-Agnostic Speaker Diarization, Transcription, and Translation") as demo:
    gr.Markdown(
        """
        # Language-Agnostic Speaker Diarization, Transcription, and Translation
        Upload an audio file (WAV, MP3, etc.) or record directly from your microphone. 
        The system will identify speakers, transcribe their speech (in detected language), 
        and provide an English translation for relevant segments.
        """
    )

    with gr.Row():
        audio_input = gr.Audio(
            type="filepath", 
            sources=["upload", "microphone"], 
            label="Upload Audio File or Record from Microphone"
        )
        
    with gr.Row():
        process_button = gr.Button("Process Audio", variant="primary")

    with gr.Column():
        detected_language_output = gr.Textbox(label="Detected Overall Language")
        # Diarized Transcription will still be chronological with speaker labels
        diarized_transcription_output = gr.Textbox(label="Diarized Transcription (Chronological with Speaker Labels)", lines=10, interactive=False)
        # Translated transcription will now be clearly separated by speaker
        translated_transcription_output = gr.Textbox(label="Translated Transcription (to English, per Speaker)", lines=10, interactive=False)
        
        # Bug fix: this component was created with visible=False and nothing
        # ever toggled it back, so the generated .txt report could never be
        # downloaded. It must stay visible to receive the output filepath.
        download_button = gr.File(label="Download Transcription (.txt)", interactive=False)

    process_button.click(
        fn=process_audio_for_web,
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button]
    )

    gr.Examples(
        [
            # Add paths to your example audio files here.
            # These files must be present in your Hugging Face Space repository.
            # For example, if you have 'sample_two_speakers.wav' in your repo:
            # "sample_two_speakers.wav" 
        ],
        inputs=audio_input,
        outputs=[diarized_transcription_output, translated_transcription_output, detected_language_output, download_button],
        fn=process_audio_for_web,
        cache_examples=False 
    )

demo.launch()