cgus commited on
Commit
4fa2818
·
verified ·
1 Parent(s): 67c97ff

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +291 -0
app.py ADDED
@@ -0,0 +1,291 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnx_asr
3
+ #import torch
4
+ from pydub import AudioSegment
5
+ from pydub.effects import normalize
6
+ import numpy as np
7
+ import csv
8
+ #import pprint
9
+ import os
10
+ import pandas as pd
11
+ from datetime import datetime
12
+
13
+ # Function to convert timestamps into sentence timestamps
14
+ def convert_to_sentence_timestamps(timestamps, tokens):
15
+ sentence_timestamps = []
16
+ start_time = None
17
+ end_time = None
18
+ current_tokens = []
19
+
20
+ for i, token in enumerate(tokens):
21
+ if token in {'.', '!', '?'}:
22
+ if start_time is not None:
23
+ end_time = timestamps[i]
24
+ current_tokens.append(token)
25
+ segment = ''.join(current_tokens).strip()
26
+ sentence_timestamps.append({
27
+ 'start': f"{start_time:.2f}",
28
+ 'end': f"{end_time:.2f}",
29
+ 'segment': segment
30
+ })
31
+ start_time = None
32
+ end_time = None
33
+ current_tokens = []
34
+ else:
35
+ if start_time is None:
36
+ start_time = timestamps[i]
37
+ current_tokens.append(token)
38
+
39
+ return sentence_timestamps
40
+
41
+ #providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
42
+ providers = ['CPUExecutionProvider']
43
+
44
+ def process_audio(audio_file, chunk_duration):
45
+ # Load model here (only when needed)
46
+ model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3", providers=providers).with_timestamps()
47
+
48
+ try:
49
+ # Load audio file
50
+ sound = AudioSegment.from_file(audio_file, channels=1)
51
+
52
+ # Process audio
53
+ ch = 1
54
+ sw = 2
55
+ fr = 16000
56
+ sound = normalize(sound)
57
+ sound = sound.set_channels(ch)
58
+ sound = sound.set_sample_width(sw) # PCM_16 format
59
+ sound = sound.set_frame_rate(fr)
60
+
61
+ # Process audio in X second chunks
62
+ chunk_duration = chunk_duration * 1000 # X seconds in milliseconds
63
+ total_duration = len(sound)
64
+
65
+ start_time = 0
66
+ end_time = 0
67
+ final_chunk = 0
68
+ item = 0
69
+ sentence_timestamps = []
70
+
71
+
72
+ while start_time < total_duration:
73
+ # Calculate end time for this chunk
74
+ print(f"Start time:{start_time/1000:.2f}s")
75
+ end_time = min(start_time + chunk_duration, total_duration)
76
+ print(f"chunk: {start_time/1000:.2f}s - {end_time/1000:.2f}s")
77
+ # Extract audio chunk
78
+ chunk = sound[start_time:end_time]
79
+ chunk_len = len(chunk)
80
+ if len(chunk) < chunk_duration:
81
+ print("Final chunk start")
82
+ final_chunk = 1
83
+
84
+ print(f"Current chunk length: {(chunk_len/1000):.2f}s")
85
+
86
+ # Convert chunk to numpy array
87
+ chunk_array = np.array(chunk.get_array_of_samples())
88
+
89
+ # Process chunk
90
+ output = model.recognize(chunk_array)
91
+
92
+ chunk_timestamps = convert_to_sentence_timestamps(output.timestamps, output.tokens)
93
+ end_index = len(chunk_timestamps) - 2 if not final_chunk else len(chunk_timestamps)
94
+ last_timestamp = start_time
95
+ current_timestamps = []
96
+ for i in range(end_index):
97
+ item += 1
98
+ timestamps = chunk_timestamps[i]
99
+ timestamps['start'] = f"{(float(timestamps['start']) + start_time / 1000):.2f}"
100
+ timestamps['end'] = f"{(float(timestamps['end']) + start_time / 1000):.2f}"
101
+ last_timestamp = float(timestamps['end'])
102
+
103
+ current_timestamps.append(timestamps)
104
+
105
+ start_time = last_timestamp * 1000
106
+
107
+ # Add timestamps with global offset
108
+ sentence_timestamps.extend(current_timestamps)
109
+ item += 1
110
+ if final_chunk == 1:
111
+ break
112
+
113
+ # Convert to table format
114
+ table_data = []
115
+ for i, timestamp in enumerate(sentence_timestamps):
116
+ table_data.append([
117
+ i + 1,
118
+ timestamp['start'],
119
+ timestamp['end'],
120
+ timestamp['segment']
121
+ ])
122
+
123
+ return table_data, sentence_timestamps
124
+ finally:
125
+ # Clean up model after processing
126
+ del model
127
+ # Optional: Force garbage collection
128
+ import gc
129
+ gc.collect()
130
+
131
+ def save_csv(timestamps, filename):
132
+ """Save timestamps to CSV file"""
133
+ # Convert timestamps to proper format if needed
134
+ if isinstance(timestamps, pd.DataFrame):
135
+ # If it's already a DataFrame, use it directly
136
+ df = timestamps
137
+ else:
138
+ # If it's a list or other format, convert it
139
+ df = pd.DataFrame(timestamps)
140
+
141
+ # Ensure we have the right column names
142
+ if len(df.columns) >= 4:
143
+ df.columns = ['Index', 'Start (s)', 'End (s)', 'Segment']
144
+ else:
145
+ # Handle case where we get a list of dicts or similar
146
+ df = pd.DataFrame(timestamps)
147
+
148
+ timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
149
+ csv_filename = f"{filename}_{timestamp_str}.csv"
150
+ csv_path = os.path.join("output", csv_filename)
151
+
152
+ # Ensure output directory exists
153
+ os.makedirs("output", exist_ok=True)
154
+
155
+ # Save the dataframe
156
+ df.to_csv(csv_path, index=False)
157
+ return csv_path
158
+
159
+ def save_srt(timestamps, filename):
160
+ """Save timestamps to SRT file"""
161
+ # Convert to proper format if needed
162
+ if isinstance(timestamps, pd.DataFrame):
163
+ df = timestamps
164
+ else:
165
+ # Convert list of dicts to DataFrame
166
+ df = pd.DataFrame(timestamps)
167
+
168
+ timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
169
+ srt_filename = f"{filename}_{timestamp_str}.srt"
170
+ srt_path = os.path.join("output", srt_filename)
171
+
172
+ # Ensure output directory exists
173
+ os.makedirs("output", exist_ok=True)
174
+
175
+ # Generate SRT content
176
+ srt_content = []
177
+ for i, row in df.iterrows():
178
+ # Handle both DataFrame rows and list/dict formats
179
+ if isinstance(row, pd.Series):
180
+ # For DataFrame case, extract values by column name
181
+ index = i + 1
182
+ #pprint.pprint(row)
183
+ start_time = float(row['start']) if 'start' in row else float(row.iloc[0])
184
+ end_time = float(row['end']) if 'end' in row else float(row.iloc[1])
185
+ segment = str(row['segment']) if 'segment' in row else str(row.iloc[2])
186
+ else:
187
+ # Handle list/dict format - properly extract data
188
+ try:
189
+ index = i + 1
190
+ start_time = float(row[0]) # start time (index 1)
191
+ end_time = float(row[1]) # end time (index 2)
192
+ segment = str(row[2]) # segment text (index 3)
193
+ except (ValueError, IndexError):
194
+ # If conversion fails or index is out of bounds, skip this row
195
+ continue
196
+
197
+ # Convert seconds to SRT time format
198
+ def seconds_to_srt_time(seconds):
199
+ hours = int(seconds // 3600)
200
+ minutes = int((seconds % 3600) // 60)
201
+ secs = int(seconds % 60)
202
+ millisecs = int((seconds % 1) * 1000)
203
+ return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}"
204
+
205
+ srt_content.append(str(index))
206
+ srt_content.append(f"{seconds_to_srt_time(start_time)} --> {seconds_to_srt_time(end_time)}")
207
+ srt_content.append(segment)
208
+ srt_content.append("") # Empty line between subtitles
209
+
210
+ with open(srt_path, 'w', encoding='utf-8') as f:
211
+ f.write('\n'.join(srt_content))
212
+
213
+ return srt_path
214
+
215
+ def download_csv(timestamps):
216
+ """Download timestamps as CSV"""
217
+ try:
218
+ csv_path = save_csv(timestamps, "timestamps")
219
+ return csv_path
220
+ except Exception as e:
221
+ print(f"Error in download_csv: {e}")
222
+ return None
223
+
224
+ def download_srt(timestamps):
225
+ """Download timestamps as SRT"""
226
+ try:
227
+ srt_path = save_srt(timestamps, "timestamps")
228
+ return srt_path
229
+ except Exception as e:
230
+ print(f"Error in download_srt: {e}")
231
+ return None
232
+
233
+ def generate_files(timestamps):
234
+ csv_path = download_csv(timestamps)
235
+ srt_path = download_srt(timestamps)
236
+ new_csv_btn = gr.DownloadButton(label="Download CSV", value=csv_path, visible=True)
237
+ new_srt_btn = gr.DownloadButton(label="Download SRT", value=srt_path, visible=True)
238
+ return new_csv_btn, new_srt_btn
239
+ # Add CSS to hide sort buttons
240
+ custom_css = """
241
+ .cell-menu-button{
242
+ display: none !important;
243
+ }
244
+ """
245
+
246
+ with gr.Blocks(css=custom_css) as demo:
247
+ gr.Markdown("# Nvidia Parakeet v3 Timestamp Processor")
248
+ gr.Markdown("Upload an audio file, then click Transcribe to process timestamps with parakeet-tdt-0.6b-v3-onnx.")
249
+
250
+ timestamps_state = gr.State()
251
+
252
+ with gr.Row():
253
+ audio_input = gr.Audio(type="filepath", label="Upload Audio File")
254
+ chunk_duration_slider = gr.Slider(
255
+ minimum=10,
256
+ maximum=400,
257
+ value=150,
258
+ step=1,
259
+ label="Chunk Duration (seconds)"
260
+ )
261
+
262
+ transcribe_btn = gr.Button("Transcribe")
263
+
264
+ with gr.Row():
265
+ csv_btn = gr.DownloadButton(label="Download CSV", visible=False)
266
+ srt_btn = gr.DownloadButton(label="Download SRT", visible=False)
267
+
268
+ with gr.Row():
269
+ table_output = gr.Dataframe(
270
+ headers=["Index", "Start (s)", "End (s)", "Segment"],
271
+ datatype=["number", "number", "number", "str"],
272
+ label="Timestamps",
273
+ interactive=False
274
+ )
275
+
276
+
277
+
278
+ # Process audio when button is clicked
279
+ transcribe_btn.click(
280
+ fn=process_audio,
281
+ inputs=[audio_input, chunk_duration_slider],
282
+ outputs=[table_output, timestamps_state]
283
+ )
284
+
285
+ timestamps_state.change(
286
+ fn=generate_files,
287
+ inputs=[timestamps_state],
288
+ outputs=[csv_btn, srt_btn]
289
+ )
290
+
291
+ demo.launch()