# DummySpace / app.py
# Uploaded by tky823 ("Upload text", commit dc3b40d)
import datetime
import functools
import importlib
import io
import os
import sys
from typing import Any, Dict, List
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchaudio
from huggingface_hub import snapshot_download, upload_file
# Hugging Face access token. Required by build_model() (it raises when unset)
# and gates the best-effort dataset uploads in perform_separation()/submit_text().
token = os.getenv("HUGGINGFACE_TOKEN")
def main() -> None:
    """Assemble and launch the Gradio demo.

    Two tabs: source separation (uploaded audio -> one player per separated
    source) and text collection (textbox whose content is uploaded to a
    dataset repo).
    """
    torch.manual_seed(0)
    model = build_model()

    # Components are instantiated up front and rendered inside the layout below.
    title_markdown = gr.Markdown("# 音源分離のデモ")

    # --- source-separation tab components ---
    audio_description_markdown = gr.Markdown(
        """
## ユーザ入力の収集について
このデモでは,ユーザの皆様のアップロードされた音声ファイルを収集します.
ユーザの皆さまが本デモを使用することにより,音声ファイルデータの収集に同意されたものとみなします.
同意されない場合は,本デモをご利用にならないようにお願いいたします.
"""
    )
    upload_audio = gr.Audio(type="filepath")
    audio_submission_button = gr.Button("分離を実行")
    # One output player per source the model separates into.
    display_audios = [gr.Audio() for _ in range(model.num_sources)]

    # --- text-collection tab components ---
    text_description_markdown = gr.Markdown(
        """
## ユーザ入力の収集について
このデモでは,ユーザの皆様のアップロードされたテキストを収集します.
ユーザの皆さまが本デモを使用することにより,テキストデータの収集に同意されたものとみなします.
同意されない場合は,本デモをご利用にならないようにお願いいたします.
"""
    )
    upload_text = gr.Textbox()
    text_submission_button = gr.Button("テキストを送信")

    with gr.Blocks() as demo:
        title_markdown.render()

        with gr.Tab("音源分離"):
            audio_description_markdown.render()
            upload_audio.render()
            audio_submission_button.render()

            # Hidden until the first separation finishes.
            with gr.Group(visible=False) as display_audio_group:
                for display_audio in display_audios:
                    display_audio.render()

            audio_submission_button.click(
                functools.partial(
                    perform_separation,
                    model=model,
                    display_audio_group=display_audio_group,
                    display_audios=display_audios,
                ),
                inputs=[upload_audio],
                outputs=[display_audio_group] + display_audios,
            )

        with gr.Tab("テキストの収集"):
            text_description_markdown.render()
            upload_text.render()
            text_submission_button.render()

            text_submission_button.click(
                functools.partial(
                    submit_text,
                    upload_text=upload_text,
                ),
                inputs=[upload_text],
                outputs=[upload_text],
            )

    demo.launch()
def build_model() -> nn.Module:
    """Download the model snapshot from the Hugging Face Hub and instantiate it.

    Returns:
        The model with pretrained weights loaded, on CPU.

    Raises:
        ValueError: If the ``HUGGINGFACE_TOKEN`` environment variable is unset.
    """
    if token is None:
        raise ValueError("Please set HUGGINGFACE_TOKEN to download model.")

    model_repo_id = "tky823/DummyModel"
    download_dir = ".download"
    local_dir = os.path.join(download_dir, model_repo_id)

    snapshot_download(
        repo_id=model_repo_id,
        local_dir=local_dir,
        token=token,
    )

    # Expected snapshot layout:
    # <local_dir>
    # |- src/        (package defining the model class)
    # |- model.pth   (checkpoint: class path, args, kwargs, state_dict)
    sys.path.append(os.path.join(local_dir, "src"))

    model_path = os.path.join(local_dir, "model.pth")
    # map_location="cpu" so a checkpoint saved on GPU still loads on a
    # CPU-only host (e.g. a free Space). NOTE(review): torch.load unpickles
    # arbitrary objects; the checkpoint deliberately stores the class path and
    # constructor arguments, so weights_only loading is not an option here —
    # the model repo must stay trusted.
    state_dict = torch.load(model_path, map_location="cpu")

    model_state_dict = state_dict["model"]["state_dict"]
    model_cls: str = state_dict["model"]["cls"]
    model_args: List[Any] = state_dict["model"]["args"]
    model_kwargs: Dict[str, Any] = state_dict["model"]["kwargs"]

    # Resolve "package.module.ClassName" to the class and instantiate it.
    module_name, cls_name = model_cls.rsplit(".", maxsplit=1)
    module = importlib.import_module(module_name)
    cls = getattr(module, cls_name)
    model: nn.Module = cls(*model_args, **model_kwargs)
    model.load_state_dict(model_state_dict)

    return model
def perform_separation(
    audio: str,
    model: nn.Module = None,
    display_audio_group: gr.Group = None,
    display_audios: List[gr.Audio] = None,
) -> Dict[Any, Any]:
    """Separate an uploaded mixture into sources and update the output players.

    Args:
        audio: Filesystem path of the uploaded audio (``gr.Audio(type="filepath")``).
        model: Separation model; its forward maps a (1, 1, time) mono mixture
            to (1, num_sources, time). Required.
        display_audio_group: Group wrapping the per-source players; made
            visible once results exist. Required.
        display_audios: One output ``gr.Audio`` per separated source. Required.

    Returns:
        Mapping of components to their updated values.

    Raises:
        ValueError: If a required keyword argument is missing, or ``audio`` is
            not a file path.
    """
    # UTC+9: Japan
    timezone = datetime.timezone(datetime.timedelta(hours=9))

    if model is None:
        raise ValueError("model is always required.")
    if display_audio_group is None:
        raise ValueError("display_audio_group is always required.")
    if display_audios is None:
        raise ValueError("display_audios is always required.")

    if isinstance(audio, str):
        audio_path = audio
        mixture, sample_rate = torchaudio.load(audio_path)

        if token:
            # Best-effort archival of the user input (see consent notice in UI).
            now = datetime.datetime.now(timezone)
            filename = os.path.basename(audio_path)
            _, ext = os.path.splitext(filename)
            now = now.strftime("%Y%m%d_%H%M%S.%f")
            # path_in_repo is a Hub repo path, not a filesystem path: it must
            # use forward slashes (os.path.join yields backslashes on Windows).
            path_in_repo = "/".join(["audio", now + ext])
            dataset_repo_id = "tky823/DummyDataset"
            repo_type = "dataset"

            try:
                upload_file(
                    path_or_fileobj=audio_path,
                    path_in_repo=path_in_repo,
                    repo_id=dataset_repo_id,
                    token=token,
                    repo_type=repo_type,
                )
            except Exception:
                # give up uploading; separation should still proceed
                pass
    else:
        raise ValueError("Invalid type of audio is given.")

    model.eval()

    with torch.no_grad():
        # Downmix to mono, insert batch dimension, separate, drop batch dimension.
        mixture = mixture.mean(dim=0, keepdim=True)
        mixture = mixture.unsqueeze(dim=0)
        separated = model(mixture)
        separated = separated.squeeze(dim=0)

    update_components = {
        display_audio_group: gr.Group(visible=True),
    }

    # The UI was built with exactly model.num_sources players.
    assert separated.size(0) == len(display_audios)

    separated = separated.numpy()
    separated = to_int16(separated)

    for source_idx in range(len(display_audios)):
        display_audio = display_audios[source_idx]
        separated_source = separated[source_idx]
        update_components[display_audio] = gr.Audio((sample_rate, separated_source))

    return update_components
def submit_text(
    text: str, upload_text: gr.Textbox = None, encoding: str = "utf-8"
) -> Dict[Any, Any]:
    """Upload the submitted text to the dataset repo and clear the textbox.

    Args:
        text: Text entered by the user.
        upload_text: The textbox component to reset after submission. Required.
        encoding: Encoding used to serialize ``text`` before uploading.

    Returns:
        Mapping of components to their updated values (resets the textbox).

    Raises:
        ValueError: If ``upload_text`` is not given.
    """
    # UTC+9: Japan
    timezone = datetime.timezone(datetime.timedelta(hours=9))

    if upload_text is None:
        raise ValueError("upload_text is always required.")

    if token:
        now = datetime.datetime.now(timezone)
        now = now.strftime("%Y%m%d_%H%M%S.%f")
        encoded_text = text.encode(encoding)
        encoded_text = io.BytesIO(encoded_text)
        # path_in_repo is a Hub repo path, not a filesystem path: it must use
        # forward slashes (os.path.join yields backslashes on Windows).
        path_in_repo = "/".join(["text", now + ".txt"])
        dataset_repo_id = "tky823/DummyDataset"
        repo_type = "dataset"

        try:
            upload_file(
                path_or_fileobj=encoded_text,
                path_in_repo=path_in_repo,
                repo_id=dataset_repo_id,
                token=token,
                repo_type=repo_type,
            )
        except Exception:
            # give up uploading
            pass

    update_components = {
        upload_text: gr.Textbox(),
    }

    return update_components
def to_int16(waveform: np.ndarray) -> np.ndarray:
    """Convert a float waveform in [-1.0, 1.0] to 16-bit PCM samples.

    Args:
        waveform: Floating-point samples, nominally in ``[-1.0, 1.0]``.

    Returns:
        The samples as ``np.int16``, clipped to the representable range.
    """
    bits_per_sample = 16
    max_amplitude = 2 ** (bits_per_sample - 1)
    scaled = waveform * max_amplitude
    # Clip before casting: without it, a full-scale sample of exactly 1.0
    # scales to 32768, which wraps around to -32768 under a plain int16 cast.
    scaled = np.clip(scaled, -max_amplitude, max_amplitude - 1)
    return scaled.astype(np.int16)
# Launch the demo only when executed as a script (not on import).
if __name__ == "__main__":
    main()