File size: 8,522 Bytes
15ca2ca
 
a87d440
 
7323bbb
 
 
a87d440
15ca2ca
 
 
 
 
7323bbb
ad6d387
 
 
 
 
15ca2ca
 
 
a87d440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7323bbb
a87d440
15ca2ca
b4d330b
 
a87d440
15ca2ca
 
a87d440
15ca2ca
 
a87d440
 
15ca2ca
 
 
 
ad6d387
 
 
 
a87d440
ad6d387
a87d440
15ca2ca
ad6d387
15ca2ca
 
 
 
 
ad6d387
 
 
 
 
15ca2ca
 
 
 
 
 
 
 
 
ad6d387
 
15ca2ca
 
a87d440
 
 
 
 
 
 
 
 
ad6d387
 
 
a87d440
ad6d387
a87d440
 
ad6d387
a87d440
 
 
 
 
ad6d387
 
 
 
 
a87d440
 
 
 
 
 
 
 
 
ad6d387
 
a87d440
 
 
 
 
 
 
ad6d387
a87d440
15ca2ca
 
a87d440
 
 
 
 
15ca2ca
a87d440
7323bbb
a87d440
 
 
 
 
 
 
15ca2ca
a87d440
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7323bbb
 
 
a87d440
 
 
 
 
7323bbb
 
 
 
a87d440
 
15ca2ca
 
7323bbb
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import streamlit as st
import numpy as np
import soundfile as sf
import librosa
import tempfile
import os
from concurrent.futures import ThreadPoolExecutor
from functools import partial

class AudioTranscriber:
    """Split uploaded audio into short segments and transcribe them.

    The speech-recognition pipeline is loaded lazily on first use and
    transcriptions are memoised per instance, keyed by a hash of the raw
    sample bytes.
    """

    # Placeholder text returned when no model is available or a
    # transcription attempt fails.
    _FALLBACK_TEXT = "Audio content"

    def __init__(self):
        self.model = None
        self.processor = None
        self.transcription_cache = {}
        self.max_segment_duration = 5.0  # Maximum segment duration in seconds

    def set_max_segment_duration(self, duration):
        """Set the maximum duration for any segment in seconds."""
        self.max_segment_duration = duration

    def load_model(self):
        """Load the transcription pipeline on first use.

        Returns:
            The cached pipeline object, or ``None`` when loading failed (a
            Streamlit warning is shown in that case).
        """
        if self.model is not None:
            return self.model

        with st.spinner("Loading transcription model..."):
            try:
                from transformers import pipeline

                # Use a small model for transcription to save memory
                self.model = pipeline(
                    "automatic-speech-recognition",
                    model="openai/whisper-small",
                    chunk_length_s=30,
                    device="cpu",
                )
            except Exception as e:
                st.warning(f"Error loading transcription model: {str(e)}. Using fallback method.")
                self.model = None

        return self.model

    def segment_audio(self, audio_file, num_segments=5, min_segment_duration=3.0):
        """Split an uploaded audio file into consecutive segments.

        The upload is written to a temporary file, decoded with librosa
        (falling back to soundfile), and cut into equal-length pieces that
        are at most ``self.max_segment_duration`` seconds long and, where the
        clip is long enough, at least ``min_segment_duration`` seconds long.
        If both decoders fail, silent dummy segments are returned instead of
        raising.

        NOTE(review): the librosa path loads with ``sr=None``, so segments
        keep whatever sample rate the decoder reports, and that rate is not
        returned. ``transcribe_segment`` defaults to ``sr=16000``; confirm
        callers pass a matching rate.

        Args:
            audio_file: Object exposing ``getvalue()`` that returns the
                encoded audio bytes.
            num_segments: Preferred number of segments. More are produced
                when needed to respect the maximum duration, fewer when
                needed to respect the minimum duration.
            min_segment_duration: Minimum segment length in seconds.

        Returns:
            ``(segments, timestamps)``: a list of sample arrays and a list of
            ``(start_time, end_time)`` tuples in seconds.
        """
        audio_path = None
        try:
            # Save the uploaded audio to a temporary file. The path is
            # recorded before writing so the file is removed even if the
            # write itself fails.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                audio_path = tmp_file.name
                tmp_file.write(audio_file.getvalue())

            try:
                y, sr = librosa.load(audio_path, sr=None)
                return self._split_signal(y, sr, num_segments, min_segment_duration)
            except Exception as e:
                st.warning(f"Error segmenting audio: {str(e)}. Using simplified segmentation.")

                # Fallback: decode with soundfile and split the same way
                try:
                    y, sr = sf.read(audio_path)
                    return self._split_signal(y, sr, num_segments, min_segment_duration)
                except Exception as inner_e:
                    st.error(f"Critical error in audio segmentation: {str(inner_e)}")
                    # Last resort: Create dummy segments
                    segments = [np.zeros(16000) for _ in range(num_segments)]  # 1-second silent segments
                    timestamps = [(i, min(i+1, i+self.max_segment_duration)) for i in range(num_segments)]
                    return segments, timestamps
        finally:
            # Clean up temporary file
            if audio_path is not None:
                self._remove_file(audio_path)

    def _split_signal(self, y, sr, num_segments, min_segment_duration):
        """Cut decoded samples ``y`` (rate ``sr``) into equal consecutive segments."""
        total_samples = len(y)
        duration = total_samples / sr
        if duration <= 0:
            return [], []

        longest = self.max_segment_duration

        # Enough segments that each one is <= max_segment_duration
        count = max(int(num_segments), int(duration / longest) + 1)

        if min_segment_duration > 0:
            count = max(count, int(duration / min_segment_duration))
            if duration / count < min_segment_duration:
                # The requested count would make segments shorter than the
                # minimum: use the largest count that still honours it, but
                # never fewer than the maximum duration requires.
                fewest = int(np.ceil(duration / longest))
                most = int(duration / min_segment_duration)
                count = max(fewest, most, 1)

        segment_duration = duration / count

        segments = []
        timestamps = []
        # Boundaries are derived from the segment index rather than a running
        # sum, so rounding error cannot add a zero-length trailing segment.
        for index in range(count):
            is_last = index == count - 1
            start_time = index * segment_duration
            end_time = duration if is_last else (index + 1) * segment_duration

            # Convert time to samples; the last segment always reaches the
            # final sample.
            start_sample = int(start_time * sr)
            end_sample = total_samples if is_last else int(end_time * sr)
            if end_sample <= start_sample:
                # Fewer samples than segments: skip empty slices
                continue

            segments.append(y[start_sample:end_sample])
            timestamps.append((start_time, end_time))

        return segments, timestamps

    @staticmethod
    def _remove_file(path):
        """Delete ``path`` on a best-effort basis, ignoring filesystem errors."""
        try:
            os.unlink(path)
        except OSError:
            pass

    def transcribe_segment(self, segment, sr=16000):
        """Transcribe a single audio segment.

        Successful transcriptions are cached by the MD5 of the segment's
        bytes. The placeholder returned on failure is deliberately not
        cached, so a later call can retry.

        Args:
            segment: Array of samples exposing ``tobytes()``.
            sr: Sample rate used when writing the segment to disk.

        Returns:
            The transcribed text, or ``"Audio content"`` when no model is
            available or transcription raised.
        """
        # Generate a cache key based on the audio data
        import hashlib
        cache_key = hashlib.md5(segment.tobytes()).hexdigest()

        try:
            return self.transcription_cache[cache_key]
        except KeyError:
            pass

        try:
            model = self.load_model()
            if model is None:
                # Fallback: no model available
                return self._FALLBACK_TEXT

            segment_path = None
            try:
                # Reserve a temp path and close the handle before writing, so
                # the file can be reopened by name on every platform.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                    segment_path = tmp_file.name
                sf.write(segment_path, segment, sr)

                # Transcribe using the model
                result = model(segment_path)
                transcription = result["text"]
            finally:
                # Clean up temporary file even when the model call fails
                if segment_path is not None:
                    self._remove_file(segment_path)
        except Exception as e:
            st.warning(f"Error transcribing segment: {str(e)}. Using fallback method.")
            return self._FALLBACK_TEXT

        # Cache the result
        self.transcription_cache[cache_key] = transcription
        return transcription

    def transcribe_segments(self, segments, sr=16000, parallel=False, max_workers=4):
        """Transcribe multiple audio segments, optionally in parallel.

        Args:
            segments: Sequence of sample arrays.
            sr: Sample rate passed to ``transcribe_segment``.
            parallel: Use a thread pool when there is more than one segment.
            max_workers: Thread pool size for the parallel path.

        Returns:
            List of transcriptions in the same order as ``segments``.
        """
        if parallel and len(segments) > 1:
            # Load the model once up front so worker threads do not all try
            # to load it at the same time.
            self.load_model()

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                transcribe_func = partial(self.transcribe_segment, sr=sr)
                return list(executor.map(transcribe_func, segments))

        # Process sequentially
        return [self.transcribe_segment(segment, sr) for segment in segments]

    def clear_cache(self):
        """Clear the transcription cache and return ``True``."""
        self.transcription_cache = {}
        return True