brindhamanick committed on
Commit
cd81064
·
verified ·
1 Parent(s): 99c19e1

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +180 -6
train.py CHANGED
@@ -1,8 +1,182 @@
1
- import time
 
 
 
 
 
 
 
 
 
2
 
3
- print("🎧 RVC Training Started...")
4
- for i in range(1, 11):
5
- print(f"Epoch {i}/10 -> Training in progress...")
6
- time.sleep(2)
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
- print("βœ… Training completed! Model saved: weights/zeynep_rvc.pth")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torchaudio
4
+ import numpy as np
5
+ from torch.utils.data import Dataset, DataLoader
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from pathlib import Path
9
+ import argparse
10
+ from torch.nn.utils.rnn import pad_sequence
11
 
12
# ---- RVC v2 Architecture (Hubert + Pitch + ContentVec) ----
class HubertEncoder(nn.Module):
    """Content encoder: two Conv1d layers followed by a BiLSTM, projected to 256 dims.

    Expects features shaped (batch, time, input_dim) and returns (batch, time, 256).
    """

    def __init__(self, input_dim=1024, hidden_dim=768):
        super().__init__()
        self.conv1 = nn.Conv1d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
        # Bidirectional halves give hidden_dim total output width.
        self.lstm = nn.LSTM(hidden_dim, hidden_dim // 2, 2, batch_first=True, bidirectional=True)
        self.proj = nn.Linear(hidden_dim, 256)

    def forward(self, x):
        """Encode (B, T, input_dim) features into (B, T, 256)."""
        h = x.transpose(1, 2)  # Conv1d wants channels-first: (B, F, T)
        for conv in (self.conv1, self.conv2):
            h = torch.relu(conv(h))
        seq, _ = self.lstm(h.transpose(1, 2))  # back to (B, T, F) for the LSTM
        return self.proj(seq)
28
 
29
class PitchEncoder(nn.Module):
    """Encode a per-frame F0 contour (B, T) into one 256-dim pitch embedding per item.

    Two Conv1d layers over the contour, global average pooling over time, then a
    linear projection to 256 dims. Returns shape (B, 256).
    """

    def __init__(self):
        super().__init__()
        self.f0_conv = nn.Sequential(
            nn.Conv1d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, 128, 3, padding=1),
            nn.ReLU()
        )
        self.pitch_proj = nn.Linear(128, 256)

    def forward(self, f0):
        """Return (B, 256) pitch features for an F0 tensor of shape (B, T).

        Bug fix: the original `f0.unsqueeze(1).transpose(1, 2)` produced
        (B, T, 1), so Conv1d(in_channels=1) crashed for any T != 1.
        Conv1d needs channels-first input (B, 1, T).
        """
        f0 = f0.unsqueeze(1)          # (B, T) -> (B, 1, T)
        out = self.f0_conv(f0)        # (B, 128, T)
        out = out.mean(-1)            # global average pool over time -> (B, 128)
        return self.pitch_proj(out)   # (B, 256)
45
+
46
class RVCDecoder(nn.Module):
    """Fuse content features with a global pitch embedding into 1024-dim frames.

    forward(content, pitch) takes content (B, T, dim) and pitch (B, dim) and
    returns (B, T, 1024).
    """

    def __init__(self, dim=256):
        super().__init__()
        self.content_lstm = nn.LSTM(dim, dim, 2, batch_first=True, bidirectional=True)
        self.pitch_lstm = nn.LSTM(dim, dim // 2, 1, batch_first=True)
        # Bug fixes vs. original: batch_first=True so (B, T, E) tensors are
        # interpreted correctly (the default is sequence-first), and
        # kdim=dim//2 so the pitch branch's width is accepted as the key
        # (the default kdim=embed_dim=dim*2 made every call fail).
        self.fusion = nn.MultiheadAttention(dim * 2, 8, batch_first=True,
                                            kdim=dim // 2, vdim=dim * 2)
        self.output_proj = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, 1024)  # Mel output
        )

    def forward(self, content, pitch):
        """Attend content (query/value) against the broadcast pitch (key)."""
        content_out, _ = self.content_lstm(content)              # (B, T, 2*dim)
        # Pitch is one global vector per utterance: give it a length-1 time
        # axis so the batch-first LSTM sees (B, 1, dim), not an unbatched 2-D
        # tensor as in the original.
        pitch_out, _ = self.pitch_lstm(pitch.unsqueeze(1))       # (B, 1, dim//2)
        pitch_out = pitch_out.repeat(1, content_out.size(1), 1)  # (B, T, dim//2)

        fused, _ = self.fusion(content_out, pitch_out, content_out)
        return self.output_proj(fused)                           # (B, T, 1024)
65
+
66
class RVCv2(nn.Module):
    """Full RVC v2 model: content encoder + pitch encoder feeding a fusion decoder.

    NOTE(review): HubertEncoder defaults to input_dim=1024 while the dataset in
    this file produces 128-mel features — confirm the caller passes matching
    shapes before training.
    """

    def __init__(self):
        super().__init__()
        self.hubert = HubertEncoder()
        self.pitch = PitchEncoder()
        self.decoder = RVCDecoder()

    def forward(self, mel, f0):
        """Encode mel (content) and f0 (pitch), then decode the fused result."""
        return self.decoder(self.hubert(mel), self.pitch(f0))
77
+
78
# ---- Advanced Audio Dataset ----
class RVCv2Dataset(Dataset):
    """Dataset of .wav files: resampled, fixed-length, returned as (mel, f0, waveform).

    Args:
        dataset_dir: folder containing *.wav files (non-recursive glob).
        sample_rate: target sample rate for resampling.
        duration: clip length in seconds; audio is trimmed/zero-padded to it.
    """

    def __init__(self, dataset_dir, sample_rate=40000, duration=10):
        self.files = list(Path(dataset_dir).glob("*.wav"))
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_samples = int(sample_rate * duration)
        # Perf fix: build the mel transform once instead of per __getitem__ call.
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate, n_mels=128, n_fft=2048, hop_length=512
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        waveform, sr = torchaudio.load(self.files[idx])

        # Robustness fix: down-mix to mono. The original assumed one channel;
        # a stereo file produced a 3-D mel that squeeze(0) could not flatten.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample to the target rate if needed.
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)

        # Trim/pad to exactly n_samples.
        if waveform.shape[1] > self.n_samples:
            waveform = waveform[:, :self.n_samples]
        else:
            waveform = torch.nn.functional.pad(waveform, (0, self.n_samples - waveform.shape[1]))

        # Log-mel spectrogram (target), shape (n_mels, frames).
        mel = self.mel_transform(waveform).squeeze(0)
        mel = torch.log(mel + 1e-9)

        # Dummy F0 placeholder (a real implementation needs crepe/dio).
        # Bug fixes: length is the number of *time frames* (mel.shape[1]) —
        # the original used mel.shape[0], i.e. n_mels — and the redundant
        # torch.tensor() wrap around an existing tensor is dropped.
        f0 = torch.full((mel.shape[1],), 200.0)

        return mel, f0, waveform
115
+
116
def collate_fn(batch):
    """Collate (mel, f0, waveform) triples into padded batch tensors.

    Returns:
        mels: (B, n_mels, frames) padded with 0.0
        f0s:  (B, max_len) padded with 0.0
        waves: tuple of the raw waveforms, untouched.

    Bug fix: the original called `f0s.unsqueeze(1)` on the *tuple* produced by
    zip(), raising AttributeError on every batch; pad_sequence already accepts
    a sequence of 1-D tensors directly.
    """
    mels, f0s, waves = zip(*batch)
    mels = pad_sequence(mels, batch_first=True, padding_value=0.0)
    f0s = pad_sequence(f0s, batch_first=True, padding_value=0.0)
    return mels, f0s, waves
121
+
122
# ---- Training Loop ----
def train_rvc_v2(model_name: str, dataset_dir: str, sample_rate: int = 40000, epochs: int = 200, batch_size: int = 8, lr: float = 2e-4) -> None:
    """Train an RVCv2 reconstruction model on the .wav files in *dataset_dir*.

    Saves the best (lowest average loss) checkpoint to weights/{model_name}.pth.
    Uses AdamW + cosine LR annealing and gradient clipping at norm 1.0.
    """
    print(f"πŸš€ RVC v2 Training Started: {model_name}")
    # NOTE(review): os.listdir counts *all* directory entries, not only .wav
    # files, so this count can differ from len(dataset) below.
    print(f"πŸ“‚ Dataset: {dataset_dir} ({len(os.listdir(dataset_dir))} files)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"πŸ› οΈ Device: {device}")

    # Data
    dataset = RVCv2Dataset(dataset_dir, sample_rate)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

    # Model
    model = RVCv2().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.MSELoss()

    # Checkpoints go under weights/ relative to the working directory.
    os.makedirs("weights", exist_ok=True)

    best_loss = float('inf')
    for epoch in range(epochs):
        model.train()
        total_loss = 0

        for batch_idx, (mel, f0, _) in enumerate(dataloader):
            mel, f0 = mel.to(device), f0.to(device)

            optimizer.zero_grad()
            output = model(mel, f0)
            # NOTE(review): this compares the decoder output against the input
            # mel, but RVCDecoder's final Linear emits 1024-dim frames while
            # the dataset produces 128-mel features — confirm the shapes agree
            # before relying on this loss.
            loss = criterion(output, mel)  # Reconstruction
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        # Cosine annealing steps once per epoch (T_max=epochs).
        scheduler.step()
        avg_loss = total_loss / len(dataloader)

        # Keep only the best checkpoint; overwritten in place each improvement.
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), f"weights/{model_name}.pth")

        # epoch is 0-based, so this logs "Epoch 0/200", "Epoch 10/200", ...
        if epoch % 10 == 0:
            print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    print(f"βœ… RVC v2 Training Complete! Best model: weights/{model_name}.pth")
172
+
173
if __name__ == "__main__":
    # CLI entry point: parse training options and launch training.
    parser = argparse.ArgumentParser(description="RVC v2 Training")
    parser.add_argument("--model_name", required=True, help="Model name (e.g., zeynep_rvc)")
    parser.add_argument("--dataset", required=True, help="Path to dataset folder")
    parser.add_argument("--sample_rate", type=int, default=40000)
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch_size", type=int, default=8)
    # Fix: train_rvc_v2 accepts a learning rate but the CLI previously had no
    # way to set it; the default matches the function's own 2e-4.
    parser.add_argument("--lr", type=float, default=2e-4, help="Learning rate")

    args = parser.parse_args()
    train_rvc_v2(args.model_name, args.dataset, args.sample_rate, args.epochs, args.batch_size, args.lr)