# rvc_api / modules/ui.py
# Committed by aryo100: "first commit" (b5a064f)
import importlib
import os
from typing import *
import gradio as gr
import gradio.routes
import torch
from . import models, shared
from .core import preload
from .shared import ROOT_DIR
class Tab:
    """Base class for a top-level or nested UI tab.

    Subclasses live in modules under ``modules/tabs`` and override
    :meth:`title`, :meth:`sort` and :meth:`ui`. A tab may own child tabs: a
    directory next to its module with the same stem (``tabs/foo.py`` ->
    ``tabs/foo/``) whose ``.py`` modules each define a ``Tab`` subclass.
    Calling an instance builds its UI.
    """

    # Directory backing the ``modules.tabs`` package.
    TABS_DIR = os.path.join(ROOT_DIR, "modules", "tabs")

    def __init__(self, filepath: str) -> None:
        # Path of the module that defines this tab. A relative path is taken
        # relative to TABS_DIR (load_tabs passes just the file name).
        self.filepath = filepath

    def sort(self):
        """Return the ordering key among sibling tabs (lower comes first)."""
        return 1

    def title(self):
        """Return the label shown on this tab's header."""
        return ""

    def ui(self, outlet: Callable):
        """Build this tab's components.

        Args:
            outlet: Zero-argument callable that renders this tab's child tabs
                as a ``gr.Tabs`` group at the point where it is called.
        """
        pass

    def _load_children(self) -> List["Tab"]:
        """Instantiate the child tabs found in this tab's sibling directory."""
        stem = os.path.splitext(self.filepath)[0]
        if not stem:
            # An empty stem would resolve to TABS_DIR itself and re-load every
            # top-level tab (including this one) as a child.
            return []
        # os.path.join returns `stem` unchanged when it is already absolute.
        children_dir = os.path.join(Tab.TABS_DIR, stem)
        if not os.path.isdir(children_dir):
            return []

        # Child modules are imported as modules.tabs.<package>.<name>, so the
        # directory must sit strictly inside TABS_DIR.
        rel_dir = os.path.relpath(children_dir, Tab.TABS_DIR)
        parts = [
            p for p in rel_dir.replace("\\", "/").split("/") if p not in ("", ".")
        ]
        if not parts or os.pardir in parts:
            return []
        package = ".".join(parts)

        found = []
        # Sorted listing keeps ties in sort() in a stable order.
        for file in sorted(os.listdir(children_dir)):
            if not file.endswith(".py") or file == "__init__.py":
                continue
            module = importlib.import_module(f"modules.tabs.{package}.{file[:-3]}")
            tab_classes = [
                x
                for x in vars(module).values()
                if isinstance(x, type) and issubclass(x, Tab) and x is not Tab
            ]
            if tab_classes:
                # Pass the path relative to TABS_DIR so grandchildren resolve
                # the same way this tab's children do.
                found.append((os.path.join(rel_dir, file), tab_classes[0]))

        children = [tab_class(path) for path, tab_class in found]
        return sorted(children, key=lambda child: child.sort())

    def __call__(self):
        """Build this tab's UI via :meth:`ui` and return whatever it returns."""
        children = self._load_children()

        def outlet():
            with gr.Tabs():
                for child in children:
                    with gr.Tab(child.title()):
                        child()

        return self.ui(outlet)
def load_tabs() -> List[Tab]:
    """Discover and instantiate the top-level tabs in ``Tab.TABS_DIR``.

    Each ``*.py`` file is imported as ``modules.tabs.<name>``. The first
    ``Tab`` subclass (other than ``Tab`` itself) found in its namespace is
    instantiated with the file name. Modules without one are skipped.

    Returns:
        Tab instances ordered by ``sort()``; ties keep file-name order.
    """
    found = []
    # Sorted listing: os.listdir order is filesystem-dependent, and sorted()
    # below is stable, so this fixes the order of tabs with equal keys.
    for file in sorted(os.listdir(Tab.TABS_DIR)):
        if not file.endswith(".py"):
            continue
        module = importlib.import_module(f"modules.tabs.{file[:-3]}")
        # isinstance(x, type) also accepts classes built with a custom
        # metaclass, which a `type(x) == type` check would silently skip.
        tab_classes = [
            x
            for x in vars(module).values()
            if isinstance(x, type) and issubclass(x, Tab) and x is not Tab
        ]
        if tab_classes:
            found.append((file, tab_classes[0]))

    # All modules are imported before any tab is instantiated.
    tabs = [tab_class(file) for file, tab_class in found]
    return sorted(tabs, key=lambda tab: tab.sort())
def webpath(fn):
    """Return a ``file=`` URL for *fn* with its mtime appended as a query.

    Files inside ROOT_DIR are referenced by a forward-slash path relative to
    ROOT_DIR; anything else is referenced by its absolute path. The query
    string changes whenever the file is modified.

    Raises:
        OSError: if *fn* does not exist (from ``os.path.getmtime``).
    """
    root = os.path.abspath(ROOT_DIR)
    path = os.path.abspath(fn)
    try:
        # Component-wise containment: a plain string prefix test would treat
        # "/app/root2/x" as inside "/app/root".
        inside_root = os.path.commonpath([root, path]) == root
    except ValueError:
        # Paths on different drives (Windows) share no common path.
        inside_root = False

    if inside_root:
        web_path = os.path.relpath(path, root).replace("\\", "/")
    else:
        web_path = path
    return f"file={web_path}?{os.path.getmtime(fn)}"
def javascript_html():
    """Return a newline-terminated ``<script>`` tag loading ROOT_DIR/script.js."""
    src = webpath(os.path.join(ROOT_DIR, "script.js"))
    return f'<script type="text/javascript" src="{src}"></script>\n'
def css_html():
    """Return a ``<link>`` stylesheet tag loading ROOT_DIR/styles.css."""
    stylesheet = os.path.join(ROOT_DIR, "styles.css")
    href = webpath(stylesheet)
    return f'<link rel="stylesheet" property="stylesheet" href="{href}">'
def create_head():
    """Patch gradio's TemplateResponse to inject our CSS and JS into ``<head>``.

    The tags are built once, when this function runs. The replacement
    response calls the original ``TemplateResponse`` stashed in
    ``shared.gradio_template_response_original`` and inserts the tags before
    every ``</head>`` in the rendered body. Because the original comes from
    ``shared``, calling this more than once never wraps a wrapper.
    """
    head_tags = css_html() + javascript_html()
    injected = f"{head_tags}</head>".encode("utf8")

    def template_response(*args, **kwargs):
        response = shared.gradio_template_response_original(*args, **kwargs)
        response.body = response.body.replace(b"</head>", injected)
        # The body was rewritten, so regenerate the headers
        # (presumably to refresh Content-Length).
        response.init_headers()
        return response

    gradio.routes.templates.TemplateResponse = template_response
def create_ui():
    """Build the gradio Blocks app with one tab per discovered Tab.

    Runs ``preload()`` first, then lays out every tab from ``load_tabs()``
    inside a single ``gr.Tabs`` group, and finally installs the head-injection
    patch via ``create_head()``.
    """
    preload()
    app = gr.Blocks()
    # Multi-item `with` enters left to right, so gr.Tabs() is created inside app.
    with app, gr.Tabs():
        for tab in load_tabs():
            with gr.Tab(tab.title()):
                tab()
    create_head()
    return app
def create_model_list_ui(speaker_id: bool = True, load: bool = True):
    """Create a row with a model dropdown, a speaker-ID slider and a reload button.

    Args:
        speaker_id: Whether the speaker-ID slider may be shown at all. When
            False the slider stays hidden, including after a model change.
        load: If True, selecting a model loads it with ``models.load_model``.
            Otherwise the model is only built with ``models.get_vc_model`` to
            read ``n_spk``, then deleted and the CUDA cache emptied.

    Returns:
        ``(model_dropdown, speaker_id_slider)`` gradio components.
    """
    # Slider state shared by the callbacks below.
    speaker_id_info = {
        "visible": False,
        "maximum": 10000,
    }

    def reload_model(raw=False):
        """Refresh the model list, load the first model, and update slider state.

        Returns the raw list when *raw* is true, else a Dropdown update.
        """
        model_list = models.get_models()
        if len(model_list) > 0:
            models.load_model(model_list[0])
        if models.vc_model is not None:
            speaker_id_info["visible"] = True
            speaker_id_info["maximum"] = models.vc_model.n_spk
        return model_list if raw else gr.Dropdown.update(choices=model_list)

    model_list = reload_model(raw=True)

    def load_model(model_name):
        """Handle a dropdown change and return a Slider update with the new range."""
        if load:
            models.load_model(model_name)
            speaker_id_info["visible"] = True
            speaker_id_info["maximum"] = models.vc_model.n_spk
        else:
            model = models.get_vc_model(model_name)
            speaker_id_info["visible"] = True
            speaker_id_info["maximum"] = model.n_spk
            del model
            torch.cuda.empty_cache()
        return gr.Slider.update(
            maximum=speaker_id_info["maximum"],
            # `speaker_id` is the caller's bool flag. The slider component has
            # its own name, so this closure no longer sees a (truthy)
            # component here and hidden sliders stay hidden.
            visible=speaker_id and speaker_id_info["visible"],
        )

    with gr.Row(equal_height=False):
        model_dropdown = gr.Dropdown(
            choices=model_list,
            label="Model",
            value=model_list[0] if len(model_list) > 0 else None,
        )
        speaker_id_slider = gr.Slider(
            minimum=0,
            maximum=speaker_id_info["maximum"],
            step=1,
            label="Speaker ID",
            value=0,
            visible=speaker_id and speaker_id_info["visible"],
            interactive=True,
        )
        reload_model_button = gr.Button("♻️")
        model_dropdown.change(
            load_model, inputs=[model_dropdown], outputs=[speaker_id_slider]
        )
        reload_model_button.click(reload_model, outputs=[model_dropdown])
    return model_dropdown, speaker_id_slider
# Stash gradio's unpatched TemplateResponse exactly once, so repeated imports
# or reloads of this module never capture the already-patched version.
try:
    shared.gradio_template_response_original
except AttributeError:
    shared.gradio_template_response_original = gradio.routes.templates.TemplateResponse