#!/usr/bin/python3
# -*- coding: utf-8 -*-
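"""
Gradio "annotation" tab for the template_match_asr space: browse audio files,
split them with a VAD tagger, transcribe the splits via a remote ASR model,
cut fixed-length windows, and save/search spectral templates for matching.
"""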
import argparse
import os
from functools import lru_cache
import json
from pathlib import Path
import platform
import shutil
import tempfile
import time
from typing import Dict, List, Union
import uuid
import gradio as gr
from gradio_client import Client, handle_file
import librosa
import numpy as np
from python_speech_features import sigproc
from scipy.io import wavfile
from tqdm import tqdm
from project_settings import environment, project_path
from toolbox.vm.vad import Tagger, correct_labels, split_signal_by_labels
from toolbox.early_media.template_match import AudioTemplateMatch, wave_to_spectrum
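# Resample a waveform by round-tripping through a temporary wav file:
# librosa reloads it at the target rate, and the result is rescaled to int16.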
def convert_sample_rate(signal: np.ndarray, sample_rate: int, target_sample_rate: int):
filename = save_input_audio(sample_rate, signal)
signal, _ = librosa.load(filename, sr=target_sample_rate)
signal = np.array(signal * (1 << 15), dtype=np.int16)
return signal
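# Persist an in-memory waveform to a uniquely named wav file under the system
# temp directory and return its path.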
def save_input_audio(sample_rate: int, signal: np.ndarray) -> str:
temp_audio_dir = Path(tempfile.gettempdir()) / "input_audio"
temp_audio_dir.mkdir(parents=True, exist_ok=True)
filename = temp_audio_dir / f"{uuid.uuid4()}.wav"
filename = filename.as_posix()
wavfile.write(
filename,
sample_rate, signal
)
return filename
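# Derive a 5-character id from the millisecond timestamp; collisions are
# unlikely within a single annotation session, but not impossible.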
def generate_short_id():
    timestamp = int(time.time() * 1000)  # millisecond timestamp
return f"{timestamp}"[-5:]
def get_annotation_tab(trained_model_dir: str, win_size: float = 2.0, win_step: float = 0.25):
# prepare
trained_model_dir = Path(trained_model_dir)
# audio_split_dataset samples
audio_split_dataset_samples = []
model_name_to_labels = {
"sound-2-l3-ch32-cnn": ["voice", "non_voice"],
"sound-8-l3-ch32-cnn": ["voice", "voicemail", "white_noise", "mute", "noise_mute", "bell", "music", "noise"],
}
model_name_choices = list(model_name_to_labels.keys())
tagger_dict: Dict[str, Tagger] = {
model_name: Tagger(
model_file=trained_model_dir / f"{model_name}.zip",
win_size=win_size,
win_step=win_step,
) for model_name in model_name_to_labels.keys()
}
asr_model_name_dict: Dict[str, Client] = {
"whisper-large-v3": "hf-audio/whisper-large-v3",
}
asr_model_name_choices = list(asr_model_name_dict.keys())
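    # Cache one gradio client per ASR source so repeated calls reuse the connection.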
@lru_cache(maxsize=100)
def get_asr_client(src: str):
client = Client(src)
return client
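    # Scan every wav under voice_dir, match it against the loaded templates, and
    # move it into a sibling directory named after the highest-weighted matched label.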
def when_click_move_wav_button(template_dir: str, voice_dir: str):
template_dir = Path(template_dir)
config_json_file = template_dir / "config.json"
with open(config_json_file.as_posix(), "r", encoding="utf-8") as f:
config_json = json.load(f)
audio_template_match = AudioTemplateMatch(
wave_to_spectrum=wave_to_spectrum,
sample_rate=8000,
template_crop=0.1,
threshold=0.007,
)
gr.Info("loading templates ...")
audio_template_match.load_template(path=template_dir)
count = 0
gr.Info("search similar templates ...")
voice_dir = Path(voice_dir)
for filename in tqdm(voice_dir.glob("**/*.wav")):
sample_rate, signal = wavfile.read(filename.as_posix())
if sample_rate != 8000:
raise gr.Error("sample rate not 8000")
matches = audio_template_match.template_match_by_wave(wave=signal)
if len(matches) == 0:
continue
matches_ = list()
for match in matches:
label = match["label"]
matches_.append({
**match,
"weight": config_json[label].get("weight", 0.0)
})
matches_ = list(sorted(matches_, key=lambda x: x["weight"], reverse=True))
labels = [match["label"] for match in matches_]
labels_ = [label for label in labels if label not in ("music",)]
if len(set(labels_)) > 1:
print("超过两个模板类别被匹配,请检测是否匹配正确。")
print(filename)
for match in matches_:
print(match)
# continue
labels_ = labels_[:1]
if len(labels_) == 0:
label = "music"
else:
label = labels_[0]
if filename.parts[-2] != label:
tgt = filename.parent.parent / label
os.makedirs(tgt, exist_ok=True)
try:
shutil.move(filename.as_posix(), tgt.as_posix())
count += 1
except shutil.Error:
print(filename)
print(tgt)
continue
gr.Info(f"resolved wav count: {count}")
return
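    # Save the current clip as a new template at <template_dir>/<label>/<basename>.wav,
    # resampling to 8000 Hz first if necessary.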
def when_click_save_template_button(audio_t, relative_name: str, template_dir: str):
if relative_name.startswith("/"):
relative_name = relative_name[1:]
parts = relative_name.split("/")
if len(parts) != 2:
            raise gr.Error(
                f"relative_name should contain a label and a basename, e.g. `busy/busy.wav`, not `{relative_name}`.")
template_dir = Path(template_dir)
dst_file = template_dir / relative_name
dst_file.parent.mkdir(parents=True, exist_ok=True)
if dst_file.exists():
            raise gr.Error(f"template file already exists: {dst_file.as_posix()}")
dst_file = dst_file.as_posix()
sample_rate, signal = audio_t
if sample_rate != 8000:
signal = convert_sample_rate(signal, sample_rate, 8000)
sample_rate = 8000
src_file = save_input_audio(sample_rate, signal)
shutil.move(src=src_file, dst=dst_file)
return dst_file, dst_file
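    # Match the given audio against all loaded templates and list the matches
    # sorted ascending by their min_val score.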
def when_click_search_similar_template_button(audio_t, template_dir: str):
audio_template_match = AudioTemplateMatch(
wave_to_spectrum=wave_to_spectrum,
sample_rate=8000,
template_crop=0.1,
threshold=0.7,
)
gr.Info("loading templates ...")
audio_template_match.load_template(path=template_dir)
sample_rate, signal = audio_t
if sample_rate != 8000:
raise gr.Error("sample rate not 8000")
gr.Info("search similar templates ...")
similar_templates_dataset_state = list()
matches = audio_template_match.template_match_by_wave(wave=signal)
for match in matches:
label = match["label"]
filename = match["filename"]
min_val = match["min_val"]
min_val = round(min_val, 4)
similar_templates_dataset_state.append([label, filename, min_val])
similar_templates_dataset_state = list(sorted(similar_templates_dataset_state, key=lambda x: x[2]))
gr.Info(f"similar templates count: {len(similar_templates_dataset_state)}")
similar_templates_dataset = gr.Dataset(
components=[t_template_label, t_template_audio, t_template_score],
samples=similar_templates_dataset_state,
)
return similar_templates_dataset_state, similar_templates_dataset
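    # Cut the selected split into win_size-second frames with a win_step stride
    # and expose them as a dataset of short clips for template annotation.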
def when_click_two_second_button(audio_file: str, transcribe: str):
sample_rate, signal = wavfile.read(audio_file)
frames = sigproc.framesig(
sig=signal,
frame_len=win_size * sample_rate,
frame_step=win_step * sample_rate,
# winfunc=np.hamming
)
two_second_audio_dataset_state = list()
basename = Path(audio_file).stem
for j, frame in enumerate(frames):
out_root = Path(tempfile.gettempdir()) / "template_match_asr"
out_root.mkdir(parents=True, exist_ok=True)
to_filename = out_root / f"{basename}_{j}.wav"
frame = np.array(frame, dtype=np.int16)
wavfile.write(
filename=to_filename,
rate=sample_rate,
data=frame
)
two_second_audio_dataset_state.append([
to_filename.as_posix(),
transcribe,
])
two_second_audio_dataset = gr.Dataset(
components=[t_two_second_audio],
samples=two_second_audio_dataset_state,
)
return two_second_audio_dataset_state, two_second_audio_dataset
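    # Transcribe the selected split with the chosen remote ASR model and write
    # the text back into the matching row of the split dataset.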
def when_click_do_asr_button(audio_file: str, dataset_state: List[list], model_name: str = "whisper-large-v3"):
dataset_state = dataset_state if isinstance(dataset_state, list) else dataset_state.value
        client_src = asr_model_name_dict.get(model_name)
        if client_src is None:
            raise gr.Error(f"invalid asr model name: {model_name}.")
        asr_client: Client = get_asr_client(client_src)
transcribe = asr_client.predict(
inputs=handle_file(audio_file),
task="transcribe",
api_name="/predict_1"
)
new_split_audio_samples_state = list()
for sample in dataset_state:
sample_file = sample[0]
if sample_file == audio_file:
row = [sample_file, transcribe]
else:
row = sample
new_split_audio_samples_state.append(row)
audio_split_dataset = gr.Dataset(
components=[t_split_audio, t_split_audio_transcribe],
samples=new_split_audio_samples_state,
visible=True
)
return transcribe, dataset_state, audio_split_dataset
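    # Tag the input audio with the selected model, correct the label sequence
    # toward target_label, and write each resulting segment to its own temp wav.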
    def when_click_audio_split_button(audio_t, model_name: str = "sound-2-l3-ch32-cnn", target_label: str = "voice"):
        sample_rate, signal = audio_t
        if len(signal) < 2.0 * sample_rate:
            raise gr.Error("audio duration should be greater than 2 seconds.")
        tagger: Tagger = tagger_dict.get(model_name)
        if tagger is None:
            raise gr.Error(f"invalid model name: {model_name}.")
        # scale int16 samples to [-1, 1] before tagging
        labels = tagger.tag(signal / (1 << 15))
        labels = correct_labels(labels, target_label=target_label)
        if target_label not in labels:
            raise gr.Error(f"no {target_label} split found.")
sub_signal_list = split_signal_by_labels(signal, labels)
audio_split_files = list()
for i, sub_signal_group in enumerate(sub_signal_list):
out_root = Path(tempfile.gettempdir()) / "template_match_asr"
out_root.mkdir(parents=True, exist_ok=True)
to_filename = out_root / f"{generate_short_id()}.wav"
sub_signal = sub_signal_group["sub_signal"]
sub_signal = np.array(sub_signal, dtype=np.int16)
wavfile.write(
filename=to_filename.as_posix(),
rate=sample_rate,
data=sub_signal
)
audio_split_files.append(to_filename.as_posix())
split_audio_dataset_state = [[name, None] for name in audio_split_files]
audio_split_dataset = gr.Dataset(
components=[t_split_audio, t_split_audio_transcribe],
samples=split_audio_dataset_state,
visible=True
)
return split_audio_dataset_state, audio_split_dataset
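    # Load the next unprocessed wav from the work directory into the input player.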
def when_click_next_audio_button(workdir: str):
workdir = Path(workdir)
choices = list(workdir.glob("*.wav"))
gr.Info(f"rest count: {len(choices)}")
if len(choices) == 0:
raise gr.Error(f"no audio anymore.")
filename = choices[0]
        return filename.as_posix(), filename.as_posix()
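    # Move the current input file into the music or bell directory configured
    # in the settings tab.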
def when_click_move_button(filename: str, move_target_dir_name: str):
if move_target_dir_name == "music":
music_dir = Path(t_wav_music_dir.value)
music_dir.mkdir(parents=True, exist_ok=True)
shutil.move(filename, music_dir.as_posix())
elif move_target_dir_name == "bell":
bell_dir = Path(t_wav_bell_dir.value)
bell_dir.mkdir(parents=True, exist_ok=True)
shutil.move(filename, bell_dir.as_posix())
else:
raise gr.Error(f"invalid move to dir: {move_target_dir_name}")
return None, None
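    # Permanently delete the current input file from disk.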
def when_click_delete_button(filename: str):
os.remove(filename)
return None, None
# ui
with gr.TabItem("settings", visible=True):
t_wav_voice_dir = gr.Textbox(
value=(project_path / environment.get("wav_voice_dir")).as_posix(),
label="wav_voice_dir"
)
t_wav_bell_dir = gr.Textbox(
value=(project_path / environment.get("wav_bell_dir")).as_posix(),
label="wav_bell_dir"
)
t_wav_music_dir = gr.Textbox(
value=(project_path / environment.get("wav_music_dir")).as_posix(),
label="wav_music_dir"
)
t_templates_dir = gr.Textbox(
value=(project_path / environment.get("templates_dir")).as_posix(),
label="templates_dir"
)
with gr.TabItem("annotation", visible=True):
with gr.Row():
with gr.Column():
t_template_file = gr.Textbox(visible=True, label="template_file")
with gr.Row():
t_template_label = gr.Textbox(label="template_label")
t_template_score = gr.Textbox(label="template_score")
t_template_audio = gr.Audio(label="template_audio")
t_similar_templates_dataset_state = gr.State(value=[])
t_similar_templates_dataset = gr.Dataset(
components=[t_template_label, t_template_audio, t_template_score],
samples=t_similar_templates_dataset_state.value,
)
t_similar_templates_dataset.click(
fn=lambda x: (
"/".join(Path(x[1]).parts[-2:]),
x[1], x[0], x[2]
),
inputs=[t_similar_templates_dataset],
outputs=[t_template_file, t_template_audio, t_template_label, t_template_score]
)
# two second audio
with gr.Column():
t_two_second_audio_file = gr.Textbox(visible=True, label="two_second_audio_file")
t_two_second_audio = gr.Audio(label="two_second_audio")
with gr.Row():
t_search_similar_button = gr.Button(value="search_similar", variant="primary")
t_save_template_button = gr.Button(value="save_template", variant="primary")
t_search_similar_button.click(
fn=when_click_search_similar_template_button,
inputs=[t_two_second_audio, t_templates_dir],
outputs=[
t_similar_templates_dataset_state, t_similar_templates_dataset
],
)
t_save_template_button.click(
fn=when_click_save_template_button,
inputs=[t_two_second_audio, t_two_second_audio_file, t_templates_dir],
outputs=[
t_two_second_audio_file, t_two_second_audio
],
)
t_two_second_audio_dataset_state = gr.State(value=[])
t_two_second_audio_dataset = gr.Dataset(
components=[t_two_second_audio],
samples=t_two_second_audio_dataset_state.value,
)
t_two_second_audio_dataset.click(
fn=lambda x: (
str(x[0]).split("/")[-1],
x[0]
),
inputs=[t_two_second_audio_dataset],
outputs=[t_two_second_audio_file, t_two_second_audio]
)
                t_two_second_audio_dataset_clear = gr.Button(value="clear 2s splits", variant="primary")
t_two_second_audio_dataset_clear.click(
                    fn=lambda: (
                        [],
                        gr.Dataset(
                            components=[t_two_second_audio],
                            samples=[],
                        )
                    ),
inputs=None,
outputs=[t_two_second_audio_dataset_state, t_two_second_audio_dataset]
)
# split audio
with gr.Column():
t_split_audio_file = gr.Textbox(visible=False, label="split_audio_file")
t_split_audio = gr.Audio(label="split_audio")
t_split_audio_transcribe = gr.Textbox(label="split_audio_transcribe")
t_asr_model_name = gr.Dropdown(choices=asr_model_name_choices, value=asr_model_name_choices[0],
label="asr_model")
with gr.Row():
t_asr_button = gr.Button(value="do_asr", variant="primary")
t_two_second = gr.Button(value="two_second", variant="primary")
t_split_audio_dataset_state = gr.State(value=[])
t_audio_split_dataset = gr.Dataset(
components=[t_split_audio, t_split_audio_transcribe],
samples=t_split_audio_dataset_state.value,
)
t_audio_split_dataset.click(
fn=lambda x: (x[0], x[0], x[1]),
inputs=[t_audio_split_dataset],
outputs=[t_split_audio_file, t_split_audio, t_split_audio_transcribe]
)
t_asr_button.click(
fn=when_click_do_asr_button,
inputs=[t_split_audio_file, t_split_audio_dataset_state, t_asr_model_name],
outputs=[t_split_audio_transcribe, t_split_audio_dataset_state, t_audio_split_dataset]
)
t_two_second.click(
fn=when_click_two_second_button,
inputs=[t_split_audio_file, t_split_audio_transcribe],
outputs=[t_two_second_audio_dataset_state, t_two_second_audio_dataset]
)
# input audio
with gr.Column():
t_input_audio_file = gr.Text(visible=False, label="input_audio_file")
t_input_audio = gr.Audio(label="input_audio")
with gr.Row():
t_next_audio_button = gr.Button(value="next", variant="primary")
t_delete_button = gr.Button(value="delete", variant="primary")
t_move_wav = gr.Button(value="move_wav", variant="primary")
t_model_name = gr.Dropdown(
choices=model_name_choices, value=model_name_choices[0], label="model_name"
)
t_target_label = gr.Dropdown(choices=["voice"], value="voice", label="target_label")
t_model_name.change(
fn=lambda x: gr.Dropdown(
choices=model_name_to_labels[x],
value=model_name_to_labels[x][0],
label="target_label"
),
inputs=t_model_name, outputs=t_target_label,
)
t_audio_split_button = gr.Button(value="audio_split", variant="primary")
t_move_dir_name = gr.Dropdown(
choices=["music", "bell"], value=None, label="move_dir_name"
)
t_move_button = gr.Button(value="move", variant="primary")
t_next_audio_button.click(
fn=when_click_next_audio_button,
inputs=[t_wav_voice_dir],
outputs=[t_input_audio_file, t_input_audio],
)
t_audio_split_button.click(
fn=when_click_audio_split_button,
inputs=[t_input_audio, t_model_name, t_target_label],
outputs=[t_split_audio_dataset_state, t_audio_split_dataset],
)
t_move_button.click(
fn=when_click_move_button,
inputs=[t_input_audio_file, t_move_dir_name],
outputs=[t_input_audio_file, t_input_audio],
)
t_delete_button.click(
fn=when_click_delete_button,
inputs=[t_input_audio_file],
outputs=[t_input_audio_file, t_input_audio],
)
t_move_wav.click(
fn=when_click_move_wav_button,
inputs=[t_templates_dir, t_wav_voice_dir],
outputs=None,
)
return locals()
if __name__ == "__main__":
pass
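    # Minimal usage sketch (hypothetical; the actual app wiring lives elsewhere
    # in this project, and the model directory name is an assumption):
    #
    #   with gr.Blocks() as demo:
    #       with gr.Tabs():
    #           get_annotation_tab(trained_model_dir="trained_models")
    #   demo.launch()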