File size: 12,639 Bytes
4d4eaf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5bff499
 
 
 
 
 
 
 
 
 
4d4eaf0
 
5bff499
 
4d4eaf0
 
5bff499
 
 
 
 
 
 
 
 
 
 
 
 
4d4eaf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddae84a
4d4eaf0
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import os
import tempfile
from datetime import timedelta
from pathlib import Path

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
from datasets import Dataset
from pyannote.audio import Pipeline
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available

class DiarizationTranscriptionTranslation:
    """End-to-end speech pipeline: speaker diarization -> Whisper
    transcription -> translation of non-English text into English.

    Heavy models are loaded once in ``__init__``; ``process_audio`` is the
    main entry point and returns a human-readable, timestamped transcript.
    """

    def __init__(self):
        # Prefer the GPU when one is visible; every sub-pipeline below
        # follows this device choice.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        # Load models for diarization, transcription, and translation.
        self.load_models()

    def load_models(self):
        """Load all required models.

        Sets:
            self.asr_pipeline         -- Whisper ASR pipeline.
            self.translation_pipeline -- multilingual -> English translator.
            self.diarization_pipeline -- pyannote pipeline, or ``None`` when
                it cannot be loaded (``diarize_audio`` then uses a naive
                fixed-window fallback).
        """
        print("Loading models...")

        # whisper-base downloads fast and is far smaller than the large
        # variants while keeping reasonable accuracy for a demo Space.
        model_name = "openai/whisper-base"

        self.asr_pipeline = pipeline(
            "automatic-speech-recognition",
            model=model_name,
            torch_dtype=torch.float32,  # float32: safest dtype for the base model
            device="cuda:0" if self.device == "cuda" else "cpu",
        )

        # Translation model (many source languages -> English).
        self.translation_pipeline = pipeline(
            "translation",
            model="Helsinki-NLP/opus-mt-mul-en",
            device=0 if self.device == "cuda" else -1,
        )

        # pyannote models are gated on the Hugging Face Hub, so try an auth
        # token first, then the local cache; if both fail we fall back to
        # the fixed-window pseudo-diarizer in diarize_audio().
        try:
            self.diarization_pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization@2.1",
                use_auth_token=True
            )
            print("Diarization model loaded successfully!")
        except Exception as e:
            print(f"Could not load diarization model with auth: {e}")
            try:
                self.diarization_pipeline = Pipeline.from_pretrained(
                    "pyannote/speaker-diarization@2.1"
                )
                print("Diarization model loaded successfully from cache!")
            except Exception as e2:
                print(f"Could not load diarization model: {e2}")
                print("Using fallback diarization method")
                self.diarization_pipeline = None

        print("Models loaded successfully!")

    def load_audio(self, file_path, sr=16000):
        """Load an audio file as a mono float waveform.

        Parameters
        ----------
        file_path : str
            Path to the audio file.
        sr : int or None
            Target sample rate; ``None`` keeps the file's native rate.

        Returns
        -------
        (numpy.ndarray, int)
            The waveform and the sample rate it is ACTUALLY at.
            BUGFIX: the original returned the pre-resampling rate even
            after resampling, which desynchronized every sample-index
            computation in ``process_audio``.
        """
        audio, orig_sr = librosa.load(file_path, sr=None)
        if sr is not None and orig_sr != sr:
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)
            return audio, sr
        return audio, orig_sr

    def diarize_audio(self, file_path):
        """Return speaker segments for *file_path*.

        Each segment is a dict with keys ``start``/``end`` (seconds) and
        ``speaker`` (label string). Uses pyannote when available; otherwise
        alternates two placeholder speakers over fixed <=5 s windows.
        """
        if self.diarization_pipeline is None:
            # Fallback: no real diarization -- just chop the file into
            # <=5 s windows and alternate two pseudo-speaker labels.
            duration = librosa.get_duration(path=file_path)
            segments = []
            current_time = 0.0
            segment_duration = min(5.0, duration)  # 5 second segments max

            speaker_id = 1
            while current_time < duration:
                end_time = min(current_time + segment_duration, duration)
                segments.append({
                    "start": current_time,
                    "end": end_time,
                    "speaker": f"Speaker {speaker_id}"
                })
                current_time = end_time
                speaker_id = 2 if speaker_id == 1 else 1  # Alternate speakers
            return segments

        # pyannote pipelines accept a file path directly and handle
        # loading/resampling internally; this replaces the previous
        # non-standard Audio().read() call.
        diarization = self.diarization_pipeline(file_path)
        return [
            {"start": turn.start, "end": turn.end, "speaker": speaker}
            for turn, _track, speaker in diarization.itertracks(yield_label=True)
        ]

    def transcribe_audio(self, file_path):
        """Transcribe one audio file with the Whisper ASR pipeline.

        Returns the pipeline's result dict (``{"text": ...}``); on failure
        returns a dict whose ``text`` carries the error message so callers
        never have to handle an exception.
        """
        try:
            # The ASR pipeline works directly on file paths. Chunking and a
            # large batch size follow the insanely-fast-whisper recipe.
            result = self.asr_pipeline(
                file_path,
                chunk_length_s=30,
                batch_size=24,  # Increased batch size for faster processing
                max_new_tokens=128,
                return_timestamps=True,
                generate_kwargs={"task": "transcribe"}  # Explicitly set task for better performance
            )
            return result
        except Exception as e:
            print(f"Error during transcription: {str(e)}")
            return {"text": f"Transcription error: {str(e)}"}

    def detect_language(self, text):
        """Best-effort language detection; returns an ISO 639-1 code.

        Falls back to a crude non-ASCII-ratio heuristic when langdetect is
        unavailable, and to ``"en"`` whenever detection fails.
        """
        # Guard empty/whitespace input up front: the original divided by
        # len(text) in the heuristic branch and crashed on "".
        if not text or not text.strip():
            return "en"
        try:
            from langdetect import detect
            return detect(text)
        except ImportError:
            # Crude heuristic: a high share of non-ASCII characters
            # suggests a non-English language. "es" is only a placeholder
            # for the demo -- a real deployment should ship langdetect.
            non_ascii = sum(1 for c in text if ord(c) > 127)
            if non_ascii / len(text) > 0.1:  # >10% non-ASCII characters
                return "es"
            return "en"
        except Exception:
            # Detection failed (e.g. langdetect on unintelligible text):
            # default to English so the pipeline keeps going.
            return "en"

    def translate_text(self, text):
        """Translate *text* to English; pass English (or failures) through."""
        try:
            detected_lang = self.detect_language(text)
            if detected_lang != "en":
                result = self.translation_pipeline(text)
                return result[0]['translation_text']
            return text
        except Exception as e:
            print(f"Translation error: {str(e)}")
            # On any translation failure, fall back to the original text.
            return text

    def process_audio(self, file_path):
        """Main processing pipeline: diarization -> transcription -> translation.

        Parameters
        ----------
        file_path : str or None
            Path to the uploaded audio file.

        Returns
        -------
        str
            A timestamped, speaker-labelled transcript, or an error
            message (this method never raises).
        """
        if not file_path:
            return "Please upload an audio file."

        try:
            # Step 1: diarize the audio into speaker segments.
            diarization_segments = self.diarize_audio(file_path)

            # Load the waveform ONCE up front; the original reloaded the
            # entire file inside the per-segment loop (quadratic I/O).
            audio, sr = self.load_audio(file_path)

            # Step 2: transcribe, then translate, each segment.
            results = []
            for segment in diarization_segments:
                start_sample = int(segment["start"] * sr)
                end_sample = int(segment["end"] * sr)
                segment_audio = audio[start_sample:end_sample]

                # Skip segments that contain no samples at all.
                if len(segment_audio) == 0:
                    continue

                # Pad segments shorter than 0.1 s with trailing silence so
                # Whisper does not choke on near-empty inputs.
                min_samples = int(sr * 0.1)
                if len(segment_audio) < min_samples:
                    segment_audio = np.pad(
                        segment_audio,
                        (0, min_samples - len(segment_audio)),
                        mode='constant',
                    )

                # Write the segment to a collision-free temporary WAV file
                # (the original used predictable names in the working dir).
                fd, temp_file = tempfile.mkstemp(suffix=".wav")
                os.close(fd)
                try:
                    # PCM_16 guarantees a widely supported WAV encoding.
                    sf.write(temp_file, segment_audio, sr,
                             format='WAV', subtype='PCM_16')

                    try:
                        transcription_result = self.transcribe_audio(temp_file)
                        # The pipeline may return a dict or a bare string.
                        if isinstance(transcription_result, dict) and "text" in transcription_result:
                            transcribed_text = transcription_result["text"]
                        elif isinstance(transcription_result, str):
                            transcribed_text = transcription_result
                        else:
                            transcribed_text = str(transcription_result)
                    except Exception as e:
                        print(f"Error transcribing segment {temp_file}: {str(e)}")
                        # Keep going: the error message stands in for the text.
                        transcribed_text = f"Transcription error: {str(e)}"

                    # Translate to English when the text is not already English.
                    translated_text = self.translate_text(transcribed_text)
                finally:
                    # Always remove the temp file, even when transcription failed.
                    try:
                        os.remove(temp_file)
                    except OSError:
                        pass  # best-effort cleanup

                results.append({
                    "start": segment["start"],
                    "end": segment["end"],
                    "speaker": segment["speaker"],
                    "transcription": transcribed_text,
                    "translation": translated_text,
                })

            # Format the results as a readable transcript.
            transcript = []
            for result in results:
                start_time = str(timedelta(seconds=int(result["start"])))
                end_time = str(timedelta(seconds=int(result["end"])))

                # Show the original alongside the translation only when
                # they differ (i.e. the source was not English).
                if result["transcription"] != result["translation"]:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}:\n"
                        f"  Original: {result['transcription']}\n"
                        f"  Translation: {result['translation']}\n"
                    )
                else:
                    transcript.append(
                        f"[{start_time} - {end_time}] {result['speaker']}: {result['translation']}"
                    )

            return "\n".join(transcript)

        except Exception as e:
            # Gradio expects a string, so surface failures as text.
            return f"Processing error: {str(e)}"


# Initialize the processor once at import time so all Gradio requests
# share the loaded models (construction triggers the slow load_models()).
processor = DiarizationTranscriptionTranslation()

def process_audio_file(audio_file):
    """Gradio callback: forward the uploaded file path to the shared processor.

    Returns a prompt string when nothing was uploaded, otherwise the
    formatted transcript produced by ``processor.process_audio``.
    """
    return (
        "Please upload an audio file."
        if audio_file is None
        else processor.process_audio(audio_file)
    )

# Create the Gradio interface.  `type="filepath"` hands process_audio_file
# a path string (not raw samples), which is what the pipeline expects.
interface = gr.Interface(
    fn=process_audio_file,
    inputs=gr.Audio(type="filepath", label="Upload Audio File"),
    outputs=gr.Textbox(label="Transcript with Speaker Labels", lines=20),
    title="Speaker Diarization, Transcription & Translation",
    description="""
    This Space combines three powerful speech processing capabilities:
    
    1. Speaker Diarization - Distinguishes between different speakers in your audio
    2. Speech Transcription - Converts spoken words into text using Whisper
    3. Automatic Translation - Translates non-English content to English
    
    Upload an audio file (MP3, WAV, or other common formats) and get a timestamped transcript with speaker labels.
    """,
    examples=[
        # Add example files if available
    ],
    cache_examples=False
)

if __name__ == "__main__":
    # 0.0.0.0:7860 is the conventional binding for Hugging Face Spaces.
    interface.launch(server_name="0.0.0.0", server_port=7860)