# Train_RVC / app.py
# aTrapDeer's picture
# Update app.py
# b7ae929 verified
import gradio as gr
import os
import shutil
# The `spaces` package is only required when the Space runs on ZeroGPU.
# When it cannot be imported (or `spaces.GPU` fails), a no-op decorator with
# the same name is defined instead so the rest of the file works unchanged.
try:
    import spaces

    gpu_decorator = spaces.GPU(duration=3600)
except Exception:

    def gpu_decorator(func=None, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Works both as ``@gpu_decorator`` (the function is returned as-is) and
        as ``@gpu_decorator(...)`` (an identity decorator is returned; keyword
        arguments are accepted and ignored).
        """
        if func is not None:
            return func

        def passthrough(target):
            return target

        return passthrough
# ---- Helpers ----
def safe_model_name(name):
    """Reduce a model name to characters that are safe inside a file path.

    Alphanumeric characters, hyphens and underscores are kept; every other
    character (spaces, dots, path separators, ...) is dropped.  A falsy
    ``name`` yields an empty string.
    """
    if not name:
        return ""

    def is_allowed(ch):
        return ch in ("-", "_") or ch.isalnum()

    return "".join(filter(is_allowed, name.strip()))
def find_weight_file(model_name):
    """Return the path of the ``.pth`` weight file for ``model_name``.

    The name is sanitized with ``safe_model_name`` and then looked up in a
    fixed list of candidate locations, in order.  The first path that exists
    is returned; ``None`` is returned when none of them exist.
    """
    sanitized = safe_model_name(model_name)
    candidates = (
        f"weights/{sanitized}.pth",
        f"/app/weights/{sanitized}.pth",
        f"/home/user/app/weights/{sanitized}.pth",
        f"/data/weights/{sanitized}.pth",
    )
    return next((path for path in candidates if os.path.exists(path)), None)
# ---- Upload audio function ----
def upload_audio(audio_files, model_name):
    """Copy uploaded audio files into the dataset folder of a model.

    Args:
        audio_files: Uploaded items from the file component.  Each item may
            be a plain path (``str`` or ``os.PathLike``) or an object that
            exposes the path of the uploaded file through a ``name``
            attribute.  Items that provide neither are skipped.
        model_name: Raw model name; it is sanitized with ``safe_model_name``
            before being used as a folder and file-name prefix.

    Returns:
        A status message for display in the UI.  The success message reports
        the number of files actually copied, not the number submitted.
    """
    model_name = safe_model_name(model_name)
    if not audio_files:
        return "⚠️ Please upload at least one audio file."
    if not model_name:
        return "⚠️ Model name cannot be empty."

    save_dir = os.path.join("dataset", model_name)
    os.makedirs(save_dir, exist_ok=True)

    saved_count = 0
    for i, file in enumerate(audio_files):
        if isinstance(file, (str, os.PathLike)):
            # Path-like items are the source path themselves.  Checked first
            # because e.g. ``pathlib.Path.name`` is only the base name.
            source_path = file
        else:
            source_path = getattr(file, "name", None)
        if source_path is None:
            continue
        # NOTE: the copy is always stored with a ``.wav`` suffix, whatever the
        # uploaded file's extension was; no audio conversion is performed.
        shutil.copy(source_path, os.path.join(save_dir, f"{model_name}_{i}.wav"))
        saved_count += 1

    if not saved_count:
        return "⚠️ None of the uploaded files could be read. Please upload them again."
    return f"✅ {saved_count} audio files saved to '{save_dir}' folder!"
# ---- Start training function ----
@gpu_decorator
def train_rvc(model_name, sample_rate, epochs, batch_size):
    """Train a model on the uploaded dataset and locate its weight file.

    Args:
        model_name: Raw model name; sanitized with ``safe_model_name``.
        sample_rate: Sample rate in Hz; converted with ``int``.
        epochs: Number of epochs; converted with ``int``.
        batch_size: Batch size; converted with ``int``.

    Returns:
        A ``(status_message, weight_path)`` tuple.  ``weight_path`` is
        ``None`` when the name is empty, the dataset folder is missing,
        training raised, or no weight file was found afterwards.
    """
    model_name = safe_model_name(model_name)
    if not model_name:
        return "⚠️ Model name cannot be empty.", None

    dataset_dir = os.path.join("dataset", model_name)
    if not os.path.exists(dataset_dir):
        return f"⚠️ {dataset_dir} folder not found. Upload audio files first.", None

    # Only the training call itself is reported as a training failure; the
    # optional persistence step below is handled separately.
    try:
        from train import train_rvc_v2

        train_rvc_v2(
            model_name=model_name,
            dataset_dir=dataset_dir,
            sample_rate=int(sample_rate),
            epochs=int(epochs),
            batch_size=int(batch_size),
        )
    except Exception as e:
        return f"❌ Training failed: {type(e).__name__}: {e}", None

    weight_path = find_weight_file(model_name)
    if not weight_path:
        return (
            f"⚠️ Training finished, but no weight file was found for model '{model_name}'.",
            None,
        )

    # Optional: copy to persistent storage if /data exists, so the file can
    # survive Space restarts when persistent storage is enabled.
    persist_warning = ""
    if os.path.exists("/data"):
        persistent_path = f"/data/weights/{model_name}.pth"
        try:
            os.makedirs("/data/weights", exist_ok=True)
            # find_weight_file may already have returned the persistent copy;
            # copying a file onto itself raises shutil.SameFileError.
            already_persistent = os.path.exists(persistent_path) and os.path.samefile(
                weight_path, persistent_path
            )
            if not already_persistent:
                shutil.copy(weight_path, persistent_path)
            weight_path = persistent_path
        except OSError as e:
            # Training succeeded, so keep offering the file that was found
            # and surface the copy problem instead of reporting a failure.
            persist_warning = (
                f"\n\n⚠️ Could not copy the weight file to /data/weights: "
                f"{type(e).__name__}: {e}"
            )

    status = (
        f"✅ Training complete!\n"
        f"Best model saved as: {weight_path}\n\n"
        f"Download it below."
    )
    return status + persist_warning, weight_path
# ---- Download existing weight function ----
def download_existing_weight(model_name):
    """Look up an already existing weight file for ``model_name``.

    Returns a ``(status_message, weight_path)`` tuple.  ``weight_path`` is
    ``None`` when the sanitized name is empty or ``find_weight_file`` finds
    nothing.
    """
    model_name = safe_model_name(model_name)
    if not model_name:
        return "⚠️ Enter a model name first.", None

    located = find_weight_file(model_name)
    if not located:
        missing_message = (
            f"⚠️ No weight file found for '{model_name}'. "
            f"If the Space restarted, the old file may have been lost unless it was saved in /data."
        )
        return missing_message, None

    return f"✅ Found existing weight file: {located}", located
# ---- Interface ----
# Three-tab UI: upload audio, start training, fetch an existing weight file.
# The "Model Name" textbox is created in the first tab and reused as an input
# by the handlers in all three tabs.
with gr.Blocks(title="RVC v2 Training") as demo:
    gr.Markdown("# 🎙️ RVC Voice Model Training Tool")
    gr.Markdown("1️⃣ Upload audio → 2️⃣ Enter model name → 3️⃣ Start training → 4️⃣ Download `.pth` weight")

    # ---- Tab 1: upload the dataset ----
    with gr.Tab("Upload Audio"):
        with gr.Row():
            audio_files = gr.File(
                label="🎧 Upload Audio Files (.wav)",
                file_count="multiple",
            )
            model_name = gr.Textbox(
                label="Model Name",
                placeholder="e.g. ryan-trey",
            )

        output_upload = gr.Textbox(label="Status", lines=3)
        upload_button = gr.Button("📦 Upload Audio", variant="primary")
        upload_button.click(
            fn=upload_audio,
            inputs=[audio_files, model_name],
            outputs=output_upload,
        )

    # ---- Tab 2: configure and run training ----
    with gr.Tab("Start Training"):
        sample_rate = gr.Dropdown(
            label="Sample Rate (Hz)",
            choices=[32000, 40000, 48000],
            value=40000,
        )
        epochs = gr.Slider(
            minimum=50,
            maximum=1000,
            value=400,
            step=50,
            label="Number of Epochs",
        )
        batch_size = gr.Slider(
            minimum=1,
            maximum=16,
            value=4,
            step=1,
            label="Batch Size",
        )

        output_train = gr.Textbox(label="Training Status", lines=8)
        trained_weight_file = gr.File(label="⬇️ Download Trained Weight File")
        train_button = gr.Button("🚀 Start Training", variant="primary")
        train_button.click(
            fn=train_rvc,
            inputs=[model_name, sample_rate, epochs, batch_size],
            outputs=[output_train, trained_weight_file],
        )

    # ---- Tab 3: fetch a weight file from an earlier run ----
    with gr.Tab("Download Existing Model"):
        gr.Markdown(
            "Use this if training already finished and the `.pth` file still exists inside the running Space."
        )

        existing_status = gr.Textbox(label="Status", lines=4)
        existing_weight_file = gr.File(label="⬇️ Existing Weight File")
        download_existing_button = gr.Button("🔎 Find Existing Weight")
        download_existing_button.click(
            fn=download_existing_weight,
            inputs=[model_name],
            outputs=[existing_status, existing_weight_file],
        )
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, with ssr_mode turned off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)