Spaces:
Runtime error
Runtime error
| import datetime | |
| import functools | |
| import importlib | |
| import io | |
| import os | |
| import sys | |
| from typing import Any, Dict, List | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
| from huggingface_hub import snapshot_download, upload_file | |
# Hugging Face token read once at import time. It is used both to download the
# model checkpoint (build_model) and to upload user submissions (the `if token:`
# guards below skip uploading entirely when it is unset).
token = os.getenv("HUGGINGFACE_TOKEN")
def main() -> None:
    """Build the separation model, assemble the Gradio UI, and launch the demo.

    The UI has two tabs: a source-separation tab (audio upload -> one output
    widget per separated source) and a text-collection tab. Both tabs upload
    user submissions to a dataset repo inside their click handlers.
    """
    # Fixed seed so any stochastic model behavior is reproducible across runs.
    torch.manual_seed(0)
    model = build_model()

    # Components are instantiated up front and placed into the layout later
    # via .render() inside the gr.Blocks context below.
    title_markdown = gr.Markdown("# 音源分離のデモ")

    # --- audio tab components ---
    audio_description_markdown = gr.Markdown(
        """
## ユーザ入力の収集について
このデモでは,ユーザの皆様のアップロードされた音声ファイルを収集します.
ユーザの皆さまが本デモを使用することにより,音声ファイルデータの収集に同意されたものとみなします.
同意されない場合は,本デモをご利用にならないようにお願いいたします.
"""
    )
    # type="filepath" makes Gradio hand the callback a path string, which is
    # what perform_separation expects.
    upload_audio = gr.Audio(type="filepath")
    audio_submission_button = gr.Button("分離を実行")
    display_audios = []
    # One output audio widget per source the model separates.
    for _ in range(model.num_sources):
        display_audios.append(gr.Audio())

    # --- text tab components ---
    text_description_markdown = gr.Markdown(
        """
## ユーザ入力の収集について
このデモでは,ユーザの皆様のアップロードされたテキストを収集します.
ユーザの皆さまが本デモを使用することにより,テキストデータの収集に同意されたものとみなします.
同意されない場合は,本デモをご利用にならないようにお願いいたします.
"""
    )
    upload_text = gr.Textbox()
    text_submission_button = gr.Button("テキストを送信")

    with gr.Blocks() as demo:
        title_markdown.render()
        with gr.Tab("音源分離"):
            audio_description_markdown.render()
            upload_audio.render()
            audio_submission_button.render()
            # Hidden until the first separation finishes; perform_separation
            # returns a component update that flips it visible.
            with gr.Group(visible=False) as display_audio_group:
                for display_audio in display_audios:
                    display_audio: gr.Audio
                    display_audio.render()
            # functools.partial binds the non-UI arguments; Gradio supplies
            # only the values listed in `inputs` to the callback.
            audio_submission_button.click(
                functools.partial(
                    perform_separation,
                    model=model,
                    display_audio_group=display_audio_group,
                    display_audios=display_audios,
                ),
                inputs=[upload_audio],
                outputs=[display_audio_group] + display_audios,
            )
        with gr.Tab("テキストの収集"):
            text_description_markdown.render()
            upload_text.render()
            text_submission_button.render()
            text_submission_button.click(
                functools.partial(
                    submit_text,
                    upload_text=upload_text,
                ),
                inputs=[upload_text],
                outputs=[upload_text],
            )
    demo.launch()
def build_model() -> nn.Module:
    """Download the model repo from the Hugging Face Hub and instantiate it.

    The checkpoint stores the model class as a dotted path plus its
    constructor arguments, so the class is imported dynamically from the
    repo's bundled ``src/`` directory.

    Returns:
        The instantiated model with its trained weights loaded.

    Raises:
        ValueError: If ``HUGGINGFACE_TOKEN`` is not set.
    """
    if token is None:
        raise ValueError("Please set HUGGINGFACE_TOKEN to download model.")

    model_repo_id = "tky823/DummyModel"
    download_dir = ".download"
    local_dir = os.path.join(download_dir, model_repo_id)

    snapshot_download(
        repo_id=model_repo_id,
        local_dir=local_dir,
        token=token,
    )

    # assumption
    # - <local_dir>
    #     |- src/
    #         |- ...
    #     |- model.pth
    #     |- ...
    # src/ holds the module that defines the model class, hence the path hack.
    sys.path.append(os.path.join(local_dir, "src"))
    model_path = os.path.join(local_dir, "model.pth")

    # map_location="cpu" so a checkpoint saved on a GPU machine still loads on
    # CPU-only hosts (Spaces default hardware).
    # NOTE(review): torch.load unpickles arbitrary objects; the checkpoint comes
    # from a token-gated repo we control, but prefer weights_only=True if the
    # checkpoint format allows it.
    state_dict = torch.load(model_path, map_location="cpu")

    model_state_dict = state_dict["model"]["state_dict"]
    model_cls: str = state_dict["model"]["cls"]
    model_args: List[Any] = state_dict["model"]["args"]
    model_kwargs: Dict[str, Any] = state_dict["model"]["kwargs"]

    # "pkg.module.ClassName" -> import pkg.module, then fetch ClassName.
    module_name, cls_name = model_cls.rsplit(".", maxsplit=1)
    module = importlib.import_module(module_name)
    cls = getattr(module, cls_name)

    model: nn.Module = cls(*model_args, **model_kwargs)
    model.load_state_dict(model_state_dict)

    return model
def perform_separation(
    audio: str,
    model: nn.Module = None,
    display_audio_group: gr.Group = None,
    display_audios: List[gr.Audio] = None,
) -> Dict[Any, Any]:
    """Separate an uploaded mixture and return Gradio component updates.

    The uploaded file is first (best-effort) archived to a dataset repo,
    then downmixed to mono and passed through the model; each separated
    source is written into its corresponding output widget.

    Args:
        audio: Path to the uploaded audio file (gr.Audio with type="filepath").
        model: Separation model; keyword-bound via functools.partial in main().
        display_audio_group: Initially-hidden group wrapping the output widgets.
        display_audios: One gr.Audio output widget per separated source.

    Returns:
        Mapping of component -> updated component (Gradio update dict).

    Raises:
        ValueError: If a bound argument is missing or `audio` is not a path string.
    """
    # UTC+9: Japan
    timezone = datetime.timezone(datetime.timedelta(hours=9))
    # The keyword arguments default to None only so partial() can bind them;
    # they are required in practice.
    if model is None:
        raise ValueError("model is always required.")
    if display_audio_group is None:
        raise ValueError("display_audio_group is always required.")
    if display_audios is None:
        raise ValueError("display_audios is always required.")
    if isinstance(audio, str):
        audio_path = audio
        mixture, sample_rate = torchaudio.load(audio_path)
        if token:
            # Archive the upload under a JST timestamp, keeping the original
            # file extension. Uploading is best-effort and must never block
            # the separation itself.
            now = datetime.datetime.now(timezone)
            filename = os.path.basename(audio_path)
            _, ext = os.path.splitext(filename)
            now = now.strftime("%Y%m%d_%H%M%S.%f")
            path_in_repo = os.path.join("audio", now + ext)
            dataset_repo_id = "tky823/DummyDataset"
            repo_type = "dataset"
            try:
                upload_file(
                    path_or_fileobj=audio_path,
                    path_in_repo=path_in_repo,
                    repo_id=dataset_repo_id,
                    token=token,
                    repo_type=repo_type,
                )
            except Exception:
                # give up uploading
                pass
    else:
        raise ValueError("Invalid type of audio is given.")
    model.eval()
    with torch.no_grad():
        # Downmix to mono: (channels, time) -> (1, time).
        mixture = mixture.mean(dim=0, keepdim=True)
        mixture = mixture.unsqueeze(dim=0)  # insert batch dimension
        # NOTE(review): assumes the model maps (1, 1, time) to
        # (1, num_sources, ...) — confirm against the model definition.
        separated = model(mixture)
        separated = separated.squeeze(dim=0)  # remove batch dimension
    update_components = {
        # Reveal the previously hidden output group.
        display_audio_group: gr.Group(visible=True),
    }
    assert separated.size(0) == len(display_audios)
    separated = separated.numpy()
    # Convert float waveforms to 16-bit PCM for the gr.Audio widgets.
    separated = to_int16(separated)
    for source_idx in range(len(display_audios)):
        display_audio = display_audios[source_idx]
        separated_source = separated[source_idx]
        # (sample_rate, ndarray) is Gradio's numpy-audio value format.
        # NOTE(review): assumes the model preserves the input sample rate.
        update_components[display_audio] = gr.Audio((sample_rate, separated_source))
    return update_components
def submit_text(
    text: str, upload_text: gr.Textbox = None, encoding: str = "utf-8"
) -> Dict[Any, Any]:
    """Best-effort upload of submitted text to the dataset repo, then reset the textbox.

    Args:
        text: The text the user submitted.
        upload_text: The textbox component to clear; keyword-bound via
            functools.partial in main().
        encoding: Encoding used when serializing the text for upload.

    Returns:
        Mapping of component -> updated component (Gradio update dict).

    Raises:
        ValueError: If `upload_text` was not bound.
    """
    # UTC+9: Japan
    jst = datetime.timezone(datetime.timedelta(hours=9))

    if upload_text is None:
        raise ValueError("upload_text is always required.")

    if token:
        # File the submission under a JST timestamp, e.g. text/20240101_120000.123456.txt
        stamp = datetime.datetime.now(jst).strftime("%Y%m%d_%H%M%S.%f")
        payload = io.BytesIO(text.encode(encoding))
        try:
            upload_file(
                path_or_fileobj=payload,
                path_in_repo=os.path.join("text", stamp + ".txt"),
                repo_id="tky823/DummyDataset",
                token=token,
                repo_type="dataset",
            )
        except Exception:
            # give up uploading
            pass

    # Replace the textbox with a fresh (empty) one.
    return {upload_text: gr.Textbox()}
def to_int16(waveform: np.ndarray) -> np.ndarray:
    """Convert a float waveform in [-1.0, 1.0] to 16-bit PCM samples.

    Values are scaled by 2**15 and clipped into the int16 range before the
    cast. Without the clip, a full-scale sample of exactly 1.0 scales to
    32768, which wraps around to -32768 under ``astype(np.int16)``.

    Args:
        waveform: Float array of samples, nominally in [-1.0, 1.0].

    Returns:
        Array of the same shape with dtype ``np.int16``.
    """
    bits_per_sample = 16
    scale = 2 ** (bits_per_sample - 1)
    waveform = waveform * scale
    # Clamp to [-32768, 32767] so out-of-range floats saturate instead of wrapping.
    waveform = np.clip(waveform, -scale, scale - 1)
    return waveform.astype(np.int16)
# Script entry point: build the model and launch the Gradio demo.
if __name__ == "__main__":
    main()