File size: 6,344 Bytes
fcc038a
 
 
b7ae929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcc038a
5daffd4
fcc038a
b7ae929
 
fcc038a
5daffd4
b7ae929
fcc038a
5daffd4
b7ae929
fcc038a
 
 
 
b7ae929
 
 
 
 
 
5daffd4
fcc038a
b7ae929
5daffd4
b7ae929
3271a5e
b7ae929
 
 
 
 
fcc038a
b7ae929
fcc038a
b7ae929
fcc038a
 
b7ae929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcc038a
b7ae929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fcc038a
5daffd4
3271a5e
b7ae929
 
fcc038a
5daffd4
fcc038a
b7ae929
 
 
 
 
 
 
 
 
3271a5e
 
b7ae929
 
 
 
 
 
fcc038a
5daffd4
b7ae929
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3271a5e
fcc038a
b7ae929
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import gradio as gr
import os
import shutil

# ZeroGPU support is optional. If the ``spaces`` package cannot be imported, or
# building its GPU decorator fails for any reason, fall back to a decorator
# that returns the wrapped function unchanged.
try:
    import spaces

    gpu_decorator = spaces.GPU(duration=3600)
except Exception:

    def gpu_decorator(func=None, **_ignored_options):
        """Pass-through stand-in for ``spaces.GPU``.

        Usable bare (``@gpu_decorator``) or with keyword options
        (``@gpu_decorator(duration=...)``); the options are accepted and ignored.
        """
        if func is not None:
            return func

        def _identity(fn):
            return fn

        return _identity


# ---- Helpers ----
def safe_model_name(name):
    """Sanitize a user-supplied model name for use in file and folder names.

    Only characters for which ``str.isalnum`` is true, plus ``-`` and ``_``,
    are kept. Separators such as ``/`` and ``.`` are therefore removed, so the
    result cannot point outside the ``dataset/`` or ``weights/`` folders.

    Args:
        name: Raw text from the UI; ``None`` or an empty string is allowed.

    Returns:
        The filtered name, or ``""`` when ``name`` is falsy.
    """
    if name:
        allowed_punctuation = "-_"
        kept = [ch for ch in name.strip() if ch.isalnum() or ch in allowed_punctuation]
        return "".join(kept)
    return ""


def find_weight_file(model_name):
    """Return the path of the first existing ``<model_name>.pth`` weight file.

    The name is sanitized with ``safe_model_name`` first. Candidate folders
    are checked in this order:

    1. ``weights/`` relative to the working directory
    2. ``/app/weights``
    3. ``/home/user/app/weights``
    4. ``/data/weights``

    Args:
        model_name: Raw or already-sanitized model name.

    Returns:
        The first matching path as a string. Returns ``None`` if the sanitized
        name is empty or no candidate is an existing regular file.
    """
    model_name = safe_model_name(model_name)

    # An empty name would otherwise probe for a hidden "weights/.pth" file.
    if not model_name:
        return None

    filename = f"{model_name}.pth"
    search_dirs = (
        "weights",
        "/app/weights",
        "/home/user/app/weights",
        "/data/weights",
    )

    for directory in search_dirs:
        path = f"{directory}/{filename}"
        # isfile rather than exists: a directory named "<x>.pth" is not a weight file.
        if os.path.isfile(path):
            return path

    return None


# ---- Upload audio function ----
def upload_audio(audio_files, model_name):
    """Copy uploaded audio files into ``dataset/<model_name>/``.

    Args:
        audio_files: Value of the multi-file ``gr.File`` component. Each item
            may be a filesystem path (``str`` / ``os.PathLike``, as newer
            Gradio versions provide) or an object exposing the path as
            ``.name`` (older tempfile-style values). Both forms are accepted.
        model_name: Raw model name typed by the user. It is sanitized with
            ``safe_model_name`` before being used as a folder name.

    Returns:
        A human-readable status string for the UI.
    """
    model_name = safe_model_name(model_name)

    if not audio_files:
        return "⚠️ Please upload at least one audio file."

    if not model_name:
        return "⚠️ Model name cannot be empty."

    save_dir = os.path.join("dataset", model_name)
    os.makedirs(save_dir, exist_ok=True)

    saved = 0
    for file in audio_files:
        # Path strings (incl. str subclasses) are used as-is. Otherwise fall back
        # to the legacy `.name` attribute. Checking only `.name` made every
        # plain-string upload get skipped.
        if isinstance(file, (str, os.PathLike)):
            source_path = file
        else:
            source_path = getattr(file, "name", None)

        if not source_path or not os.path.isfile(source_path):
            continue

        # Number by files actually saved, so skipped entries leave no gaps.
        # NOTE(review): every file is stored with a .wav extension whatever its
        # real format; confirm the training script expects that.
        shutil.copy(source_path, os.path.join(save_dir, f"{model_name}_{saved}.wav"))
        saved += 1

    if saved == 0:
        return "⚠️ None of the uploaded files could be read; nothing was saved."

    # Report the number actually copied, not the number submitted.
    return f"✅ {saved} audio files saved to '{save_dir}' folder!"


# ---- Start training function ----
@gpu_decorator
def train_rvc(model_name, sample_rate, epochs, batch_size):
    """Train an RVC v2 model on ``dataset/<model_name>/`` and return its weight file.

    Args:
        model_name: Raw model name from the UI; sanitized with ``safe_model_name``.
        sample_rate: Target sample rate in Hz (converted with ``int``).
        epochs: Number of training epochs (converted with ``int``).
        batch_size: Training batch size (converted with ``int``).

    Returns:
        A ``(status_message, weight_path_or_None)`` tuple for the Gradio Textbox
        and File outputs. Failures are reported in the message rather than raised.
    """
    model_name = safe_model_name(model_name)

    if not model_name:
        return "⚠️ Model name cannot be empty.", None

    dataset_dir = os.path.join("dataset", model_name)

    if not os.path.exists(dataset_dir):
        return f"⚠️ {dataset_dir} folder not found. Upload audio files first.", None

    try:
        # Imported inside the try, so a missing or broken `train` module is
        # reported as a training failure instead of raising out of the handler.
        from train import train_rvc_v2

        train_rvc_v2(
            model_name=model_name,
            dataset_dir=dataset_dir,
            sample_rate=int(sample_rate),
            epochs=int(epochs),
            batch_size=int(batch_size),
        )

        weight_path = find_weight_file(model_name)

        if weight_path:
            # Optional: copy to persistent storage if /data exists.
            # This helps future runs survive Space restarts if you enabled persistent storage.
            if os.path.exists("/data"):
                os.makedirs("/data/weights", exist_ok=True)
                persistent_path = f"/data/weights/{model_name}.pth"
                # find_weight_file may already have returned the /data copy.
                # Copying a file onto itself raises SameFileError, which used to
                # turn a successful run into "Training failed".
                try:
                    shutil.copy(weight_path, persistent_path)
                except shutil.SameFileError:
                    pass
                weight_path = persistent_path

            return (
                f"✅ Training complete!\n"
                f"Best model saved as: {weight_path}\n\n"
                f"Download it below.",
                weight_path,
            )

        return (
            f"⚠️ Training finished, but no weight file was found for model '{model_name}'.",
            None,
        )

    except Exception as e:
        # The UI only gets a one-line summary; keep the full traceback in the logs.
        import traceback

        traceback.print_exc()
        return f"❌ Training failed: {type(e).__name__}: {e}", None


# ---- Download existing weight function ----
def download_existing_weight(model_name):
    """Look up a previously trained ``.pth`` file for ``model_name``.

    Args:
        model_name: Raw model name from the UI; sanitized with ``safe_model_name``.

    Returns:
        ``(status_message, weight_path_or_None)`` for the Gradio Textbox and File outputs.
    """
    cleaned = safe_model_name(model_name)
    if not cleaned:
        return "⚠️ Enter a model name first.", None

    found = find_weight_file(cleaned)
    if not found:
        message = (
            f"⚠️ No weight file found for '{cleaned}'. "
            "If the Space restarted, the old file may have been lost unless it was saved in /data."
        )
        return message, None

    return f"✅ Found existing weight file: {found}", found


# ---- Interface ----
# Assemble the UI. Inside each `with` context, components are laid out in the
# order they are created, so the statement order here defines the layout.
with gr.Blocks(title="RVC v2 Training") as demo:
    gr.Markdown("# 🎙️ RVC Voice Model Training Tool")
    gr.Markdown("1️⃣ Upload audio → 2️⃣ Enter model name → 3️⃣ Start training → 4️⃣ Download `.pth` weight")

    with gr.Tab("Upload Audio"):
        with gr.Row():
            # No `type` or `file_types` is passed, so the value format and the
            # accepted extensions are Gradio's defaults. NOTE(review): upload_audio
            # stores every file with a .wav extension whatever its real format.
            audio_files = gr.File(
                file_count="multiple",
                label="🎧 Upload Audio Files (.wav)"
            )
            # This one textbox is also the `model_name` input for the training and
            # lookup tabs below, so the name can only be edited on this tab.
            model_name = gr.Textbox(
                label="Model Name",
                placeholder="e.g. ryan-trey"
            )

        output_upload = gr.Textbox(label="Status", lines=3)
        upload_button = gr.Button("📦 Upload Audio", variant="primary")

        # upload_audio returns a single status string.
        upload_button.click(
            upload_audio,
            inputs=[audio_files, model_name],
            outputs=output_upload,
        )

    with gr.Tab("Start Training"):
        # Integer choices; train_rvc still casts the selected value with int().
        sample_rate = gr.Dropdown(
            choices=[32000, 40000, 48000],
            value=40000,
            label="Sample Rate (Hz)",
        )

        # Positional args are the slider's minimum (50) and maximum (1000).
        epochs = gr.Slider(
            50,
            1000,
            value=400,
            step=50,
            label="Number of Epochs",
        )

        # Positional args are the slider's minimum (1) and maximum (16).
        batch_size = gr.Slider(
            1,
            16,
            value=4,
            step=1,
            label="Batch Size",
        )

        output_train = gr.Textbox(label="Training Status", lines=8)
        trained_weight_file = gr.File(label="⬇️ Download Trained Weight File")

        train_button = gr.Button("🚀 Start Training", variant="primary")

        # train_rvc returns (status message, weight path or None), matching these two outputs.
        train_button.click(
            train_rvc,
            inputs=[model_name, sample_rate, epochs, batch_size],
            outputs=[output_train, trained_weight_file],
        )

    with gr.Tab("Download Existing Model"):
        gr.Markdown(
            "Use this if training already finished and the `.pth` file still exists inside the running Space."
        )

        existing_status = gr.Textbox(label="Status", lines=4)
        existing_weight_file = gr.File(label="⬇️ Existing Weight File")

        download_existing_button = gr.Button("🔎 Find Existing Weight")

        # download_existing_weight also returns (status message, weight path or None).
        download_existing_button.click(
            download_existing_weight,
            inputs=[model_name],
            outputs=[existing_status, existing_weight_file],
        )


if __name__ == "__main__":
    # Listen on all network interfaces at port 7860 (Gradio's default port),
    # with server-side rendering turned off.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)