File size: 9,859 Bytes
4fa2818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58cc467
4fa2818
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import gradio as gr
import onnx_asr
#import torch
from pydub import AudioSegment
from pydub.effects import normalize
import numpy as np
import csv
#import pprint
import os
import pandas as pd
from datetime import datetime

# Function to convert timestamps into sentence timestamps
def convert_to_sentence_timestamps(timestamps, tokens):
    """Group recognizer tokens into sentences with start/end times.

    A sentence opens at the first non-terminator token and closes at the
    next '.', '!' or '?' token. Terminators seen while no sentence is open
    are ignored, and tokens after the last terminator are not emitted.

    Args:
        timestamps: Per-token times, indexed in parallel with ``tokens``.
        tokens: Token strings produced by the recognizer.

    Returns:
        A list of dicts with 'start' and 'end' (strings formatted with two
        decimals) and 'segment' (the joined, stripped sentence text).
    """
    terminators = frozenset('.!?')
    sentences = []
    opened_at = None
    pieces = []

    for position, piece in enumerate(tokens):
        if piece not in terminators:
            # Ordinary token: open a sentence if none is open, then collect.
            if opened_at is None:
                opened_at = timestamps[position]
            pieces.append(piece)
            continue

        if opened_at is None:
            # Terminator with no open sentence: nothing to close.
            continue

        closed_at = timestamps[position]
        pieces.append(piece)
        sentences.append({
            'start': f"{opened_at:.2f}",
            'end': f"{closed_at:.2f}",
            'segment': ''.join(pieces).strip(),
        })
        opened_at = None
        pieces = []

    return sentences

# Execution providers handed to onnx_asr.load_model() in process_audio().
# CPU-only is active; the commented-out line is the CUDA-first alternative.
#providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
providers = ['CPUExecutionProvider']

def process_audio(audio_file, chunk_duration):
    """Transcribe an audio file into sentence-level timestamps.

    The audio is normalized, converted to mono 16-bit PCM at 16 kHz and
    recognized in windows of ``chunk_duration`` seconds. For every window
    except the last, the final two detected sentences are held back and the
    next window restarts at the end of the last kept sentence, so sentences
    near the window boundary are recognized again in the next window.

    Args:
        audio_file: Path of the audio file to transcribe.
        chunk_duration: Window length in seconds; must be positive.

    Returns:
        A ``(table_data, sentence_timestamps)`` tuple. ``sentence_timestamps``
        is a list of dicts with 'start', 'end' (strings, seconds from the
        start of the file, two decimals) and 'segment'. ``table_data`` holds
        the same entries as ``[index, start, end, segment]`` rows.

    Raises:
        ValueError: If no audio file is given or ``chunk_duration`` is not
            positive.
    """
    import gc

    # Validate before loading the model so bad input fails fast.
    if not audio_file:
        raise ValueError("No audio file provided")
    chunk_ms = int(round(chunk_duration * 1000))
    if chunk_ms <= 0:
        raise ValueError("chunk_duration must be positive")

    # Load model here (only when needed)
    model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3", providers=providers).with_timestamps()

    try:
        sound = AudioSegment.from_file(audio_file, channels=1)

        # Normalize, then convert to mono / 16-bit samples / 16 kHz.
        sound = normalize(sound)
        sound = sound.set_channels(1)
        sound = sound.set_sample_width(2)  # PCM_16 format
        sound = sound.set_frame_rate(16000)

        total_duration = len(sound)  # milliseconds
        start_time = 0               # window start, integer milliseconds
        sentence_timestamps = []

        while start_time < total_duration:
            print(f"Start time:{start_time/1000:.2f}s")
            end_time = min(start_time + chunk_ms, total_duration)
            print(f"chunk: {start_time/1000:.2f}s - {end_time/1000:.2f}s")

            chunk = sound[start_time:end_time]
            # The window is the last one as soon as it reaches the end of
            # the audio, including the case where it fits exactly.
            final_chunk = end_time >= total_duration
            if final_chunk:
                print("Final chunk start")
            print(f"Current chunk length: {(len(chunk)/1000):.2f}s")

            chunk_array = np.array(chunk.get_array_of_samples())
            output = model.recognize(chunk_array)
            chunk_timestamps = convert_to_sentence_timestamps(output.timestamps, output.tokens)

            if final_chunk:
                keep = len(chunk_timestamps)
            else:
                # Hold back the last two sentences so the next window
                # re-recognizes them, but always keep at least one sentence
                # when any was found: keeping none would leave start_time
                # unchanged and repeat this window forever.
                keep = len(chunk_timestamps) - 2
                if keep <= 0:
                    keep = min(len(chunk_timestamps), 1)

            offset = start_time / 1000  # window start in seconds
            next_start = start_time
            for sentence in chunk_timestamps[:keep]:
                # Shift window-relative times to file-relative times.
                sentence['start'] = f"{(float(sentence['start']) + offset):.2f}"
                sentence['end'] = f"{(float(sentence['end']) + offset):.2f}"
                next_start = int(round(float(sentence['end']) * 1000))
                sentence_timestamps.append(sentence)

            if final_chunk:
                break

            # No sentence recognized (or no forward movement after
            # rounding): skip this window instead of retrying it.
            if next_start <= start_time:
                next_start = end_time
            start_time = next_start

        # Convert to table format
        table_data = [
            [number, entry['start'], entry['end'], entry['segment']]
            for number, entry in enumerate(sentence_timestamps, start=1)
        ]

        return table_data, sentence_timestamps
    finally:
        # Drop the model reference and collect so its memory can be
        # reclaimed between requests.
        del model
        gc.collect()

def save_csv(timestamps, filename):
    """Save timestamps to a time-stamped CSV file under ``output/``.

    Args:
        timestamps: A DataFrame, or any value ``pd.DataFrame`` accepts
            (for example a list of rows or a list of dicts).
        filename: Base file name; ``_<YYYYmmdd_HHMMSS>.csv`` is appended.

    Returns:
        Path of the CSV file that was written.
    """
    # Work on a copy so relabelling columns never mutates the caller's frame.
    if isinstance(timestamps, pd.DataFrame):
        df = timestamps.copy()
    else:
        df = pd.DataFrame(timestamps)

    # Only a four-column frame gets the table headers. Any other width keeps
    # its own labels; assigning four names to a wider frame would raise.
    if len(df.columns) == 4:
        df.columns = ['Index', 'Start (s)', 'End (s)', 'Segment']

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{filename}_{timestamp_str}.csv"
    csv_path = os.path.join("output", csv_filename)

    # Ensure output directory exists
    os.makedirs("output", exist_ok=True)

    df.to_csv(csv_path, index=False)
    return csv_path

def save_srt(timestamps, filename):
    """Save timestamps to a time-stamped SRT subtitle file under ``output/``.

    Start, end and text are read from columns labelled 'start', 'end' and
    'segment' when present. A missing label falls back to position: the
    first three columns, or columns 1-3 when the input has four or more
    columns (the ``[index, start, end, segment]`` table layout).

    Args:
        timestamps: A DataFrame, or any value ``pd.DataFrame`` accepts
            (for example a list of dicts or a list of rows).
        filename: Base file name; ``_<YYYYmmdd_HHMMSS>.srt`` is appended.

    Returns:
        Path of the SRT file that was written.

    Raises:
        IndexError: If a needed column is neither labelled nor available
            by position.
        ValueError: If a start or end value cannot be read as a number.
    """
    if isinstance(timestamps, pd.DataFrame):
        df = timestamps
    else:
        df = pd.DataFrame(timestamps)

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    srt_filename = f"{filename}_{timestamp_str}.srt"
    srt_path = os.path.join("output", srt_filename)

    # Ensure output directory exists
    os.makedirs("output", exist_ok=True)

    srt_content = []
    if not df.empty:
        # Skip the leading index column of the four-column table layout.
        offset = 1 if len(df.columns) >= 4 else 0
        starts = _srt_column(df, 'start', offset)
        ends = _srt_column(df, 'end', offset + 1)
        segments = _srt_column(df, 'segment', offset + 2)

        # Cues are numbered 1..n regardless of the frame's index labels.
        for number, (start, end, segment) in enumerate(zip(starts, ends, segments), start=1):
            srt_content.append(str(number))
            srt_content.append(f"{_seconds_to_srt_time(start)} --> {_seconds_to_srt_time(end)}")
            srt_content.append(str(segment))
            srt_content.append("")  # Empty line between subtitles

    with open(srt_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(srt_content))

    return srt_path


def _srt_column(df, label, position):
    """Return column ``label`` of ``df``, or the column at ``position``."""
    if label in df.columns:
        return df[label]
    return df.iloc[:, position]


def _seconds_to_srt_time(seconds):
    """Format a time in seconds as an SRT timestamp (``HH:MM:SS,mmm``)."""
    # Round to whole milliseconds first: truncating (seconds % 1) * 1000
    # turns values such as 2.3 into 299 ms because of float representation.
    total_ms = int(round(float(seconds) * 1000))
    hours, remainder = divmod(total_ms, 3600000)
    minutes, remainder = divmod(remainder, 60000)
    secs, millisecs = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}"

def download_csv(timestamps):
    """Write the timestamps as CSV and return the file path.

    Any exception raised while saving is printed and ``None`` is returned
    instead of a path.
    """
    try:
        return save_csv(timestamps, "timestamps")
    except Exception as err:
        print(f"Error in download_csv: {err}")
    return None

def download_srt(timestamps):
    """Write the timestamps as SRT and return the file path.

    Any exception raised while saving is printed and ``None`` is returned
    instead of a path.
    """
    try:
        return save_srt(timestamps, "timestamps")
    except Exception as err:
        print(f"Error in download_srt: {err}")
    return None

def generate_files(timestamps):
    """Export the timestamps and return updated CSV/SRT download buttons.

    Args:
        timestamps: Sentence timestamps to export (passed unchanged to
            ``download_csv`` and ``download_srt``).

    Returns:
        A ``(csv_button, srt_button)`` tuple of ``gr.DownloadButton``
        components pointing at the generated files.
    """
    csv_path = download_csv(timestamps)
    srt_path = download_srt(timestamps)
    # download_* return None when saving failed; only reveal a button that
    # actually has a file behind it.
    new_csv_btn = gr.DownloadButton(
        label="Download CSV", value=csv_path, visible=csv_path is not None
    )
    new_srt_btn = gr.DownloadButton(
        label="Download SRT", value=srt_path, visible=srt_path is not None
    )
    return new_csv_btn, new_srt_btn
# Extra CSS passed to gr.Blocks(css=...) below. It hides every element with
# the .cell-menu-button class, i.e. the sort buttons of the results table.
custom_css = """
.cell-menu-button{
    display: none !important;
}
"""

# Gradio UI. Components render in the order they are created below:
# intro text, inputs row, Transcribe button, download buttons, results table.
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Nvidia Parakeet v3 Timestamp Processor")
    gr.Markdown("Upload an audio file, then click Transcribe to process timestamps with parakeet-tdt-0.6b-v3-onnx.")
    gr.Markdown("This is a CPU space, so expect slow processing: about 250s per 10m audio. But total duration isn't limited.")

    # Receives the second return value of process_audio (the list of
    # sentence dicts); a change to it triggers file generation below.
    timestamps_state = gr.State()

    with gr.Row():
        # type="filepath": process_audio receives a path to the audio file.
        audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        # Window length in seconds, passed to process_audio as chunk_duration.
        chunk_duration_slider = gr.Slider(
            minimum=10,
            maximum=400,
            value=150,
            step=1,
            label="Chunk Duration (seconds)"
        )

    transcribe_btn = gr.Button("Transcribe")

    with gr.Row():
        # Hidden until generate_files returns replacements with visible=True.
        csv_btn = gr.DownloadButton(label="Download CSV", visible=False)
        srt_btn = gr.DownloadButton(label="Download SRT", visible=False)

    with gr.Row():
        # Read-only table filled with the table_data rows from process_audio.
        table_output = gr.Dataframe(
            headers=["Index", "Start (s)", "End (s)", "Segment"],
            datatype=["number", "number", "number", "str"],
            label="Timestamps",
            interactive=False
        )



    # Process audio when button is clicked: the table gets the rows and the
    # state gets the raw sentence list.
    transcribe_btn.click(
        fn=process_audio,
        inputs=[audio_input, chunk_duration_slider],
        outputs=[table_output, timestamps_state]
    )

    # When the state value changes, write the CSV/SRT files and update the
    # two download buttons.
    timestamps_state.change(
        fn=generate_files,
        inputs=[timestamps_state],
        outputs=[csv_btn, srt_btn]
    )

# NOTE(review): runs at import time; there is no `if __name__ == "__main__"`
# guard, so importing this module also starts the app.
demo.launch()