|
|
import gradio as gr |
|
|
import onnx_asr |
|
|
|
|
|
from pydub import AudioSegment |
|
|
from pydub.effects import normalize |
|
|
import numpy as np |
|
|
import csv |
|
|
|
|
|
import os |
|
|
import pandas as pd |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
def convert_to_sentence_timestamps(timestamps, tokens, sentence_endings=('.', '!', '?')):
    """Group recognizer tokens into sentences with start/end times.

    Tokens are concatenated with ``''`` (they carry their own spacing) and a
    sentence is closed by any token that, after stripping whitespace, equals
    one of ``sentence_endings``.

    Args:
        timestamps: Per-token times, parallel to ``tokens`` (treated as
            seconds by ``process_audio``). Extra entries in the longer of the
            two sequences are ignored.
        tokens: Decoded token strings.
        sentence_endings: Tokens that terminate a sentence.

    Returns:
        A list of dicts ``{'start': str, 'end': str, 'segment': str}`` with
        times formatted to two decimals. ``start`` is the time of the
        sentence's first token, and ``end`` is the time of its closing
        punctuation token. Trailing text with no closing punctuation is still
        returned as a final sentence whose ``end`` is its last token's time.
    """
    endings = frozenset(sentence_endings)
    sentences = []
    current_tokens = []
    start_time = None
    last_time = None

    for token_time, token in zip(timestamps, tokens):
        closes_sentence = token.strip() in endings

        if start_time is None and closes_sentence:
            # Punctuation with no open sentence (e.g. the second mark of "?!"):
            # attach it to the previous sentence rather than dropping it.
            if sentences:
                sentences[-1]['segment'] += token.strip()
                sentences[-1]['end'] = f"{token_time:.2f}"
            continue

        if start_time is None:
            start_time = token_time
        current_tokens.append(token)
        last_time = token_time

        if closes_sentence:
            sentences.append({
                'start': f"{start_time:.2f}",
                'end': f"{token_time:.2f}",
                'segment': ''.join(current_tokens).strip(),
            })
            start_time = None
            current_tokens = []

    # Flush text after the last sentence-ending token (e.g. the final words of
    # a recording that ends without punctuation) instead of silently losing it.
    trailing = ''.join(current_tokens).strip()
    if trailing:
        sentences.append({
            'start': f"{start_time:.2f}",
            'end': f"{last_time:.2f}",
            'segment': trailing,
        })

    return sentences
|
|
|
|
|
|
|
|
# ONNX Runtime execution providers handed to onnx_asr.load_model in
# process_audio; only the CPU provider is requested.
providers = ["CPUExecutionProvider"]
|
|
|
|
|
def _prepare_audio(audio_file):
    """Load ``audio_file`` with pydub and return it as normalized mono 16-bit 16 kHz audio."""
    sound = AudioSegment.from_file(audio_file, channels=1)
    sound = normalize(sound)
    sound = sound.set_channels(1)
    sound = sound.set_sample_width(2)  # 2 bytes per sample -> int16 samples
    return sound.set_frame_rate(16000)


def _transcribe_in_chunks(model, sound, chunk_ms):
    """Recognize ``sound`` window by window; return sentence dicts with absolute times.

    For every window except the last, the final two sentences are discarded.
    The next window then starts where the last kept sentence ended, so a
    sentence cut by the window edge is recognized again in full.

    If a window yields too few sentences to discard any, or resuming would not
    move forward, all of its sentences are kept. The next window then starts at
    this window's end. This guarantees the loop always advances.
    """
    total_ms = len(sound)
    start_ms = 0
    sentences = []

    while start_ms < total_ms:
        end_ms = min(start_ms + chunk_ms, total_ms)
        is_final = end_ms >= total_ms
        print(f"Start time:{start_ms/1000:.2f}s")
        print(f"chunk: {start_ms/1000:.2f}s - {end_ms/1000:.2f}s")
        if is_final:
            print("Final chunk start")

        chunk = sound[start_ms:end_ms]
        print(f"Current chunk length: {(len(chunk)/1000):.2f}s")

        # Samples are int16 (sample width 2); scale to a float32 waveform in
        # [-1, 1), the numpy input format onnx_asr documents.
        samples = np.array(chunk.get_array_of_samples(), dtype=np.float32) / np.float32(32768)
        if samples.size == 0:
            start_ms = end_ms
            continue

        output = model.recognize(samples)
        chunk_sentences = convert_to_sentence_timestamps(output.timestamps, output.tokens)

        offset_s = start_ms / 1000
        kept = chunk_sentences
        next_start_ms = end_ms
        if not is_final and len(chunk_sentences) > 2:
            candidate = chunk_sentences[:-2]
            resume_ms = (float(candidate[-1]['end']) + offset_s) * 1000
            if resume_ms > start_ms:
                kept, next_start_ms = candidate, resume_ms

        # Chunk-relative times -> absolute times in the whole recording.
        for sentence in kept:
            sentence['start'] = f"{float(sentence['start']) + offset_s:.2f}"
            sentence['end'] = f"{float(sentence['end']) + offset_s:.2f}"
        sentences.extend(kept)
        start_ms = next_start_ms

    return sentences


def process_audio(audio_file, chunk_duration):
    """Transcribe an audio file in chunks and return sentence-level timestamps.

    Args:
        audio_file: Path of the uploaded file (``gr.Audio(type="filepath")``),
            or None/empty when nothing was uploaded.
        chunk_duration: Recognition window length in seconds.

    Returns:
        ``(table_data, sentence_timestamps)``:
        - ``table_data`` holds rows ``[number, start, end, segment]`` (1-based)
          for the Dataframe.
        - ``sentence_timestamps`` is the list of sentence dicts with string
          ``'start'``/``'end'`` (seconds, absolute) and ``'segment'``.

    Raises:
        gr.Error: If no audio file was given or ``chunk_duration`` is not positive.
    """
    if not audio_file:
        raise gr.Error("Please upload an audio file before clicking Transcribe.")
    if not chunk_duration or chunk_duration <= 0:
        raise gr.Error("Chunk duration must be a positive number of seconds.")

    model = onnx_asr.load_model("nemo-parakeet-tdt-0.6b-v3", providers=providers).with_timestamps()
    try:
        sound = _prepare_audio(audio_file)
        sentence_timestamps = _transcribe_in_chunks(model, sound, chunk_duration * 1000)
        table_data = [
            [number, sentence['start'], sentence['end'], sentence['segment']]
            for number, sentence in enumerate(sentence_timestamps, start=1)
        ]
        return table_data, sentence_timestamps
    finally:
        # The model is loaded per request and explicitly released afterwards.
        del model
        import gc
        gc.collect()
|
|
|
|
|
def save_csv(timestamps, filename):
    """Save timestamps to a CSV file and return its path.

    Args:
        timestamps: A pandas DataFrame, or anything ``pd.DataFrame`` accepts
            (e.g. a list of ``{'start', 'end', 'segment'}`` dicts or a list of
            ``[index, start, end, segment]`` rows).
        filename: Base name. The file is written to
            ``output/<filename>_<YYYYmmdd_HHMMSS>.csv``, and the ``output``
            directory is created if missing. Two calls with the same base name
            within the same second write to the same path.

    Data with exactly four columns is treated as the table layout and gets
    the headers Index / Start (s) / End (s) / Segment. Any other shape keeps
    its own column names. A DataFrame argument is never modified in place.

    Returns:
        The path of the written CSV file.
    """
    if isinstance(timestamps, pd.DataFrame):
        # Copy so renaming the columns below does not alter the caller's frame.
        df = timestamps.copy()
    else:
        df = pd.DataFrame(timestamps)

    # Only an exact 4-column frame matches the 4 header names; a wider frame
    # would make the assignment raise a length-mismatch ValueError.
    if len(df.columns) == 4:
        df.columns = ['Index', 'Start (s)', 'End (s)', 'Segment']

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    csv_filename = f"{filename}_{timestamp_str}.csv"
    csv_path = os.path.join("output", csv_filename)

    os.makedirs("output", exist_ok=True)
    df.to_csv(csv_path, index=False)
    return csv_path
|
|
|
|
|
def _seconds_to_srt_time(seconds):
    """Format ``seconds`` as an SRT timestamp ``HH:MM:SS,mmm``.

    The value is rounded to whole milliseconds first. Splitting off the
    fraction with ``seconds % 1`` would truncate float error, turning e.g.
    1.23 s into ``,229``.
    """
    total_ms = int(round(seconds * 1000))
    hours, rem = divmod(total_ms, 3_600_000)
    minutes, rem = divmod(rem, 60_000)
    secs, millisecs = divmod(rem, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millisecs:03d}"


def _srt_row_values(row):
    """Return the raw ``(start, end, segment)`` values of one DataFrame row.

    Rows with named columns (``start``/``end``/``segment`` or the table
    headers ``Start (s)``/``End (s)``/``Segment``) are read by name. Otherwise
    the values are read by position. A 4-column row is taken to be
    ``[index, start, end, segment]``; anything else is taken to be
    ``[start, end, segment, ...]``.
    """
    for names in (('start', 'end', 'segment'), ('Start (s)', 'End (s)', 'Segment')):
        if all(name in row.index for name in names):
            return tuple(row[name] for name in names)
    offset = 1 if len(row) == 4 else 0
    return row.iloc[offset], row.iloc[offset + 1], row.iloc[offset + 2]


def save_srt(timestamps, filename):
    """Save timestamps to an SRT subtitle file and return its path.

    Args:
        timestamps: A pandas DataFrame, or anything ``pd.DataFrame`` accepts
            (e.g. a list of ``{'start', 'end', 'segment'}`` dicts with times in
            seconds).
        filename: Base name. The file is written to
            ``output/<filename>_<YYYYmmdd_HHMMSS>.srt``, and the ``output``
            directory is created if missing.

    Rows whose start/end cannot be converted to a finite float, or that have
    too few columns, are skipped. Cues are numbered consecutively from 1
    across the rows actually written.

    Returns:
        The path of the written SRT file.
    """
    if isinstance(timestamps, pd.DataFrame):
        df = timestamps
    else:
        df = pd.DataFrame(timestamps)

    timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S")
    srt_filename = f"{filename}_{timestamp_str}.srt"
    srt_path = os.path.join("output", srt_filename)

    os.makedirs("output", exist_ok=True)

    srt_content = []
    cue_number = 0
    for _, row in df.iterrows():
        try:
            start_raw, end_raw, segment = _srt_row_values(row)
            timing = (
                f"{_seconds_to_srt_time(float(start_raw))} --> "
                f"{_seconds_to_srt_time(float(end_raw))}"
            )
        except (ValueError, TypeError, IndexError, OverflowError):
            # Unparseable / NaN / infinite times or too few columns: skip the row.
            continue

        cue_number += 1
        srt_content.extend([str(cue_number), timing, str(segment), ""])

    with open(srt_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(srt_content))

    return srt_path
|
|
|
|
|
def download_csv(timestamps):
    """Write ``timestamps`` to a CSV file named with the "timestamps" prefix.

    Returns the path produced by ``save_csv``. On any exception the error is
    printed and None is returned instead of propagating.
    """
    csv_path = None
    try:
        csv_path = save_csv(timestamps, "timestamps")
    except Exception as err:
        print(f"Error in download_csv: {err}")
    return csv_path
|
|
|
|
|
def download_srt(timestamps):
    """Write ``timestamps`` to an SRT file named with the "timestamps" prefix.

    Returns the path produced by ``save_srt``. On any exception the error is
    printed and None is returned instead of propagating.
    """
    srt_path = None
    try:
        srt_path = save_srt(timestamps, "timestamps")
    except Exception as err:
        print(f"Error in download_srt: {err}")
    return srt_path
|
|
|
|
|
def generate_files(timestamps):
    """Write CSV and SRT files for ``timestamps`` and return the download buttons.

    Returns a pair of ``gr.DownloadButton`` updates (CSV, SRT).

    - A button is shown only when its file was actually written. The download
      helpers return None on failure, and a visible button with no file would
      download nothing.
    - When ``timestamps`` is None or empty, no files are written and both
      buttons are hidden.
    """
    if timestamps is None or len(timestamps) == 0:
        csv_path = srt_path = None
    else:
        csv_path = download_csv(timestamps)
        srt_path = download_srt(timestamps)

    new_csv_btn = gr.DownloadButton(label="Download CSV", value=csv_path, visible=csv_path is not None)
    new_srt_btn = gr.DownloadButton(label="Download SRT", value=srt_path, visible=srt_path is not None)
    return new_csv_btn, new_srt_btn
|
|
|
|
|
# Extra CSS passed to gr.Blocks(css=...) below. It hides every element with
# the `.cell-menu-button` class. Presumably this is the per-cell menu button
# rendered by gr.Dataframe (TODO confirm against the installed Gradio version).
custom_css = """
.cell-menu-button{
    display: none !important;
}
"""
|
|
|
|
|
with gr.Blocks(css=custom_css) as demo:
    # Title and usage notes, rendered top to bottom in this order.
    gr.Markdown("# Nvidia Parakeet v3 Timestamp Processor")
    gr.Markdown("Upload an audio file, then click Transcribe to process timestamps with parakeet-tdt-0.6b-v3-onnx.")
    gr.Markdown("This is a CPU space, so expect slow processing: about 250s per 10m audio. But total duration isn't limited.")

    # Holds the sentence list returned by process_audio; changes to it
    # trigger regeneration of the download files.
    timestamps_state = gr.State()

    # Inputs: uploaded file path and recognition window length.
    with gr.Row():
        audio_input = gr.Audio(type="filepath", label="Upload Audio File")
        chunk_duration_slider = gr.Slider(
            minimum=10,
            maximum=400,
            value=150,
            step=1,
            label="Chunk Duration (seconds)",
        )

    transcribe_btn = gr.Button("Transcribe")

    # Download buttons start hidden; generate_files returns replacements.
    with gr.Row():
        csv_btn = gr.DownloadButton(label="Download CSV", visible=False)
        srt_btn = gr.DownloadButton(label="Download SRT", visible=False)

    # Read-only results table: one row per sentence.
    with gr.Row():
        table_output = gr.Dataframe(
            headers=["Index", "Start (s)", "End (s)", "Segment"],
            datatype=["number", "number", "number", "str"],
            label="Timestamps",
            wrap=True,
            interactive=False,
        )

    # Event wiring: (fn, inputs, outputs).
    transcribe_btn.click(
        process_audio,
        [audio_input, chunk_duration_slider],
        [table_output, timestamps_state],
    )
    timestamps_state.change(
        generate_files,
        [timestamps_state],
        [csv_btn, srt_btn],
    )

demo.launch()