| import gradio as gr |
| import os |
| import shutil |
|
|
| |
| |
# Use the Hugging Face `spaces` GPU decorator when it can be imported and
# configured; otherwise fall back to a decorator that leaves functions untouched.
try:
    import spaces

    gpu_decorator = spaces.GPU(duration=3600)
except Exception:

    def gpu_decorator(func=None, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Works both as ``@gpu_decorator`` (the function is returned unchanged)
        and as ``@gpu_decorator(...)`` (keyword options are accepted and
        ignored, and an identity decorator is returned).
        """

        def _identity(target):
            return target

        return _identity if func is None else func
|
|
|
|
| |
def safe_model_name(name):
    """Reduce a user-supplied model name to characters safe for file paths.

    Only alphanumeric characters, hyphens and underscores are kept; everything
    else (spaces, dots, path separators, ...) is dropped.

    Args:
        name: Raw model name; may be ``None`` or empty.

    Returns:
        The filtered name, or ``""`` when ``name`` is falsy.
    """
    if not name:
        return ""

    extra_allowed = ("-", "_")
    kept = []
    for ch in name.strip():
        if ch in extra_allowed or ch.isalnum():
            kept.append(ch)
    return "".join(kept)
|
|
|
|
def find_weight_file(model_name):
    """Locate the ``.pth`` weight file belonging to a model.

    The name is sanitized with ``safe_model_name`` and ``<name>.pth`` is then
    looked up, in this order, in the relative ``weights`` directory,
    ``/app/weights``, ``/home/user/app/weights`` and ``/data/weights``.

    Args:
        model_name: Raw model name as entered by the user.

    Returns:
        The first candidate path that is an existing regular file, or ``None``
        when the name sanitizes to an empty string or no candidate is found.
    """
    model_name = safe_model_name(model_name)
    if not model_name:
        # An empty name would otherwise probe for a literal ".pth" file.
        return None

    weight_dirs = (
        "weights",
        "/app/weights",
        "/home/user/app/weights",
        "/data/weights",
    )

    for directory in weight_dirs:
        candidate = f"{directory}/{model_name}.pth"
        # isfile rather than exists: a directory whose name ends in ".pth"
        # cannot be copied or offered for download.
        if os.path.isfile(candidate):
            return candidate

    return None
|
|
|
|
| |
def upload_audio(audio_files, model_name):
    """Copy uploaded audio files into the dataset folder of a model.

    Each file is stored as ``dataset/<model>/<model>_<index>.wav`` where
    ``<index>`` is the file's position in ``audio_files``. The ``.wav`` suffix
    is applied regardless of the source file's own extension.

    Args:
        audio_files: Uploaded items. Each item may be a path (``str`` or
            ``os.PathLike``) or an object exposing its path through a ``name``
            attribute; anything else is skipped.
        model_name: Raw model name; sanitized with ``safe_model_name``.

    Returns:
        A status message describing how many files were actually saved, or a
        warning when the input is unusable.

    Raises:
        OSError: Propagated from directory creation or file copying.
    """
    model_name = safe_model_name(model_name)

    if not audio_files:
        return "⚠️ Please upload at least one audio file."

    if not model_name:
        return "⚠️ Model name cannot be empty."

    save_dir = os.path.join("dataset", model_name)
    os.makedirs(save_dir, exist_ok=True)

    saved_count = 0
    for i, file in enumerate(audio_files):
        if isinstance(file, (str, os.PathLike)):
            # Check path-likes first: pathlib.Path also has a ``name``
            # attribute, but it is only the basename, not a usable path.
            source_path = os.fspath(file)
        else:
            source_path = getattr(file, "name", None)

        if not source_path:
            continue

        shutil.copy(source_path, os.path.join(save_dir, f"{model_name}_{i}.wav"))
        saved_count += 1

    if saved_count == 0:
        return "⚠️ None of the uploaded items could be read as audio files."

    # Report the number of files really copied, not the number submitted.
    return f"✅ {saved_count} audio files saved to '{save_dir}' folder!"
|
|
|
|
| |
@gpu_decorator
def train_rvc(model_name, sample_rate, epochs, batch_size):
    """Train a model on its uploaded dataset and return the resulting weights.

    Training is delegated to ``train.train_rvc_v2``, which is imported inside
    the ``try`` block so that an import failure is reported as a training
    failure instead of propagating. After training, the weight file is located
    with ``find_weight_file``; when ``/data`` exists the file is also copied to
    ``/data/weights/<model>.pth`` and that path is returned.

    Args:
        model_name: Raw model name; sanitized with ``safe_model_name``.
        sample_rate: Sample rate, converted with ``int()``.
        epochs: Number of epochs, converted with ``int()``.
        batch_size: Batch size, converted with ``int()``.

    Returns:
        A ``(status_message, weight_path)`` tuple. ``weight_path`` is ``None``
        when validation fails, training raises, or no weight file is found.
    """
    model_name = safe_model_name(model_name)

    if not model_name:
        return "⚠️ Model name cannot be empty.", None

    dataset_dir = os.path.join("dataset", model_name)

    if not os.path.exists(dataset_dir):
        return f"⚠️ {dataset_dir} folder not found. Upload audio files first.", None

    try:
        from train import train_rvc_v2

        train_rvc_v2(
            model_name=model_name,
            dataset_dir=dataset_dir,
            sample_rate=int(sample_rate),
            epochs=int(epochs),
            batch_size=int(batch_size),
        )

        weight_path = find_weight_file(model_name)

        if not weight_path:
            return (
                f"⚠️ Training finished, but no weight file was found for model '{model_name}'.",
                None,
            )

        if os.path.exists("/data"):
            os.makedirs("/data/weights", exist_ok=True)
            persistent_path = f"/data/weights/{model_name}.pth"

            # The located file may already be the persistent copy (found
            # directly under /data/weights, or reached through a symlinked
            # weights directory). Copying a file onto itself raises
            # shutil.SameFileError, which would be misreported below as a
            # training failure.
            already_persistent = os.path.exists(
                persistent_path
            ) and os.path.samefile(weight_path, persistent_path)
            if not already_persistent:
                shutil.copy(weight_path, persistent_path)

            weight_path = persistent_path

        return (
            f"✅ Training complete!\n"
            f"Best model saved as: {weight_path}\n\n"
            f"Download it below.",
            weight_path,
        )

    except Exception as e:
        # UI boundary: surface any failure as a status message.
        return f"❌ Training failed: {type(e).__name__}: {e}", None
|
|
|
|
| |
def download_existing_weight(model_name):
    """Look up an already-trained weight file for the given model name.

    Args:
        model_name: Raw model name; sanitized with ``safe_model_name``.

    Returns:
        A ``(status_message, weight_path)`` tuple; ``weight_path`` is ``None``
        when the name is empty after sanitizing or no weight file is found.
    """
    cleaned_name = safe_model_name(model_name)
    if not cleaned_name:
        return "⚠️ Enter a model name first.", None

    located = find_weight_file(cleaned_name)
    if not located:
        not_found_message = (
            f"⚠️ No weight file found for '{cleaned_name}'. "
            "If the Space restarted, the old file may have been lost unless it was saved in /data."
        )
        return not_found_message, None

    return f"✅ Found existing weight file: {located}", located
|
|
|
|
| |
# UI definition: three tabs that all read the model name from the single
# textbox created in the first tab.
with gr.Blocks(title="RVC v2 Training") as demo:
    gr.Markdown("# 🎙️ RVC Voice Model Training Tool")
    gr.Markdown("1️⃣ Upload audio → 2️⃣ Enter model name → 3️⃣ Start training → 4️⃣ Download `.pth` weight")

    # --- Tab 1: copy uploaded audio into the model's dataset folder ---
    with gr.Tab("Upload Audio"):
        with gr.Row():
            audio_files = gr.File(file_count="multiple", label="🎧 Upload Audio Files (.wav)")
            model_name = gr.Textbox(label="Model Name", placeholder="e.g. ryan-trey")

        output_upload = gr.Textbox(label="Status", lines=3)
        upload_button = gr.Button("📦 Upload Audio", variant="primary")

        upload_button.click(
            fn=upload_audio,
            inputs=[audio_files, model_name],
            outputs=output_upload,
        )

    # --- Tab 2: choose hyperparameters and run training ---
    with gr.Tab("Start Training"):
        sample_rate = gr.Dropdown(
            choices=[32000, 40000, 48000],
            value=40000,
            label="Sample Rate (Hz)",
        )

        epochs = gr.Slider(minimum=50, maximum=1000, value=400, step=50, label="Number of Epochs")

        batch_size = gr.Slider(minimum=1, maximum=16, value=4, step=1, label="Batch Size")

        output_train = gr.Textbox(label="Training Status", lines=8)
        trained_weight_file = gr.File(label="⬇️ Download Trained Weight File")

        train_button = gr.Button("🚀 Start Training", variant="primary")

        train_button.click(
            fn=train_rvc,
            inputs=[model_name, sample_rate, epochs, batch_size],
            outputs=[output_train, trained_weight_file],
        )

    # --- Tab 3: fetch a weight file left over from an earlier run ---
    with gr.Tab("Download Existing Model"):
        gr.Markdown(
            "Use this if training already finished and the `.pth` file still exists inside the running Space."
        )

        existing_status = gr.Textbox(label="Status", lines=4)
        existing_weight_file = gr.File(label="⬇️ Existing Weight File")

        download_existing_button = gr.Button("🔎 Find Existing Weight")

        download_existing_button.click(
            fn=download_existing_weight,
            inputs=[model_name],
            outputs=[existing_status, existing_weight_file],
        )
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 with ssr_mode disabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)