File size: 10,437 Bytes
8c85b97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
# core/transcription_service.py
import whisperx
import gradio as gr
import gc
import time
import os
from datetime import datetime
import warnings

# Suppress specific pyannote/pytorch warning about degrees of freedom
warnings.filterwarnings("ignore", message="std\(\): degrees of freedom is <= 0")

from src.config.settings import Config

# CORRECT WAY: Import DiarizationPipeline at point of use
from whisperx.diarize import DiarizationPipeline

class TranscriptionService:
    """End-to-end video transcription pipeline built on WhisperX.

    Pipeline stages: audio extraction -> Whisper transcription ->
    word-level timestamp alignment -> speaker diarization -> speaker
    assignment -> markdown formatting. Models are loaded lazily and
    cached on the instance for reuse across calls.
    """

    def __init__(self):
        # Config is used as a settings holder (WHISPER_MODEL, DEVICE,
        # COMPUTE_TYPE, HUGGINGFACE_TOKEN are read from it).
        self.config = Config
        self.models_loaded = False  # lazy-load guard for load_models()
        self.whisper_model = None
        self.diarize_model = None
        self.batch_size = 16  # Whisper inference batch size

    def load_models(self):
        """Load the Whisper and diarization models once (idempotent)."""
        if self.models_loaded:
            return
        print("πŸ“₯ Loading transcription models...")

        # Model name comes from config rather than being hardcoded.
        self.whisper_model = whisperx.load_model(
            self.config.WHISPER_MODEL,
            self.config.DEVICE,
            compute_type=self.config.COMPUTE_TYPE,
            language="en",
        )
        self.diarize_model = DiarizationPipeline(
            use_auth_token=self.config.HUGGINGFACE_TOKEN,
            device=self.config.DEVICE,
        )

        self.models_loaded = True
        print("βœ… Models loaded successfully")

    def _report_progress(self, progress_callback, fraction, desc):
        """Report progress via the optional callback, never letting it abort.

        The original code only guarded the first callback invocation with
        try/except; a failing callback at any later stage would kill the
        whole transcription. The short sleep gives UI frontends (e.g. a
        Gradio progress bar) time to render the update.
        """
        if progress_callback is None:
            return
        try:
            progress_callback(fraction, desc=desc)
        except Exception as e:
            # A broken progress UI must not abort a long transcription run.
            print(f"⚠️ Progress callback failed: {e}")
        time.sleep(0.5)

    def transcribe_video(self, video_file_path, progress_callback=None):
        """Run the full transcription pipeline on a video file.

        Args:
            video_file_path: Path to the input video/audio file.
            progress_callback: Optional callable(fraction, desc=...) used
                to report progress to a UI; errors from it are swallowed.

        Returns:
            On success: dict with "success": True plus "transcription"
            (markdown), "timing_info" (markdown), "raw_data" (whisperx
            result), "processing_time" (seconds), "speakers_count".
            On failure: {"success": False, "error": <message>}.
        """
        try:
            if not self.models_loaded:
                self.load_models()

            start_time = time.time()
            print(f"🎬 Processing video: {os.path.basename(video_file_path)}")

            # STEP 1: Extract the audio track straight from the video.
            self._report_progress(progress_callback, 0.1, "🎬 Loading audio from video...")
            print("1️⃣ Loading audio directly from video...")
            audio = whisperx.load_audio(video_file_path)
            print(f"βœ… Audio loaded: {len(audio)} samples")

            # STEP 2: Transcribe with Whisper.
            self._report_progress(progress_callback, 0.4, "πŸ“ Transcribing audio...")
            print("3️⃣ Transcribing audio...")
            result = self.whisper_model.transcribe(audio, batch_size=self.batch_size)
            # Save the detected language now: whisperx.align() returns a
            # new dict that does not carry the "language" key.
            detected_language = result["language"]
            print(f"βœ… Transcription complete ({detected_language} detected)")

            # STEP 3: Word-level timestamp alignment.
            self._report_progress(progress_callback, 0.5, "⏱️ Aligning timestamps...")
            print("4️⃣ Aligning word-level timestamps...")
            model_a, metadata = whisperx.load_align_model(
                language_code=detected_language,
                device=self.config.DEVICE,
            )
            result = whisperx.align(
                result["segments"],
                model_a,
                metadata,
                audio,
                self.config.DEVICE,
                return_char_alignments=False,
            )
            result["language"] = detected_language  # restore after align()
            # The alignment model is only needed for this step; release it
            # eagerly to reduce peak memory before diarization.
            del model_a
            gc.collect()
            print("βœ… Timestamps aligned")

            # STEP 4: Speaker diarization.
            self._report_progress(progress_callback, 0.7, "πŸ‘₯ Identifying speakers...")
            print("5️⃣ Loading speaker diarization model...")
            diarize_segments = self.diarize_model(audio)

            # STEP 5: Map diarized speaker turns onto the aligned words.
            # Done exactly once (the previous version called
            # assign_word_speakers twice in a row).
            self._report_progress(progress_callback, 0.9, "πŸ”— Assigning speakers to text...")
            print("6️⃣ Assigning speakers to transcript...")
            result = whisperx.assign_word_speakers(diarize_segments, result)
            print("βœ… Speaker assignment complete")

            self._report_progress(progress_callback, 1.0, "βœ… Complete!")

            # STEP 6: Format results for display.
            processing_time = time.time() - start_time
            transcription = self._format_results(result, video_file_path)
            timing_info = self._get_timing_info(result, processing_time, video_file_path)

            return {
                "success": True,
                "transcription": transcription,
                "timing_info": timing_info,
                "raw_data": result,  # keep for potential storage
                "processing_time": processing_time,
                "speakers_count": len(
                    {seg.get("speaker", "UNKNOWN") for seg in result["segments"]}
                ),
            }

        except Exception as e:
            error_msg = f"Transcription failed: {str(e)}"
            print(f"❌ ERROR: {error_msg}")
            return {
                "success": False,
                "error": error_msg,
            }

    def _format_results(self, result, video_file_path):
        """Format the whisperx result as a markdown transcript.

        Produces a header with meeting metadata, the speaker-labelled
        transcript body, and a statistics footer.
        """
        if not result["segments"]:
            return "No transcription segments found"

        # Extract meeting metadata.
        segments = result["segments"]
        speakers = set(segment.get("speaker", "UNKNOWN") for segment in segments)
        total_duration = segments[-1]["end"] if segments else 0
        language = result.get("language", "unknown")

        # Aggregate statistics.
        total_words = sum(len(seg.get("text", "").split()) for seg in segments)
        avg_segment_length = total_words / len(segments) if segments else 0

        # Header with meeting context.
        output = "# 🎯 Meeting Transcription\n\n"
        output += "## πŸ“‹ Meeting Information\n\n"
        output += f"**πŸ“ File:** `{os.path.basename(video_file_path)}`\n"
        output += f"**πŸ“… Date:** {datetime.now().strftime('%Y-%m-%d %H:%M')}\n"
        output += f"**⏱️ Duration:** {self._format_timestamp(total_duration)}\n"
        output += f"**πŸ‘₯ Speakers:** {len(speakers)}\n"
        output += f"**🌐 Language:** {language.upper()}\n"
        output += f"**πŸ€– Model:** {self.config.WHISPER_MODEL}\n\n"
        output += "---\n\n"
        output += "## πŸ’¬ Transcript\n\n"

        # Transcript body: emit a speaker heading only when the speaker
        # changes, then timestamped lines for each segment.
        current_speaker = None
        for segment in segments:
            speaker = segment.get("speaker", "UNKNOWN")
            start_time = self._format_timestamp(segment["start"])

            if speaker != current_speaker:
                output += f"\n**πŸ‘€ {speaker}:**\n"
                current_speaker = speaker

            output += f"[{start_time}] {segment['text'].strip()}\n"

        # Statistics footer.
        output += "\n---\n\n"
        output += "## πŸ“Š Transcript Statistics\n\n"
        output += f"**Total Segments:** {len(segments)}\n"
        output += f"**Total Words:** {total_words:,}\n"
        output += f"**Avg Words/Segment:** {avg_segment_length:.1f}\n"
        output += f"**Unique Speakers:** {len(speakers)}\n"
        output += f"**Speaker IDs:** {', '.join(sorted(speakers))}\n"

        return output

    def _get_timing_info(self, result, processing_time, video_file_path):
        """Generate a markdown processing-statistics summary.

        speed_ratio > 1 means the pipeline ran faster than real time.
        """
        if not result["segments"]:
            return "No timing information available"

        total_duration = result["segments"][-1]["end"]
        speed_ratio = total_duration / processing_time if processing_time > 0 else 0
        video_name = os.path.basename(video_file_path)

        # Consistent one-stat-per-line markdown (the previous version mixed
        # stray "\n" prefixes into some lines, giving uneven spacing).
        return f"""
## ⏱️ Processing Statistics

**File:** {video_name}
**Duration:** {self._format_timestamp(total_duration)}
**Processing Time:** {processing_time:.1f}s
**Speed:** {speed_ratio:.1f}x ({'Faster' if speed_ratio > 1 else 'Slower'} than real-time)
**Completed:** {datetime.now().strftime("%H:%M:%S")}
"""

    def _format_timestamp(self, seconds):
        """Convert seconds to an MM:SS string; None is treated as zero."""
        if seconds is None:
            return "00:00"
        minutes = int(seconds // 60)
        seconds = int(seconds % 60)
        return f"{minutes:02d}:{seconds:02d}"