aTrapDeer commited on
Commit
4936b4d
·
verified ·
1 Parent(s): 14850a9

Update train.py

Browse files

fixing tupple issue

Files changed (1) hide show
  1. train.py +192 -81
train.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import torch
3
  import torchaudio
4
- import numpy as np
5
  from torch.utils.data import Dataset, DataLoader
6
  import torch.nn as nn
7
  import torch.optim as optim
@@ -9,174 +8,286 @@ from pathlib import Path
9
  import argparse
10
  from torch.nn.utils.rnn import pad_sequence
11
 
12
- # ---- RVC v2 Architecture (Hubert + Pitch + ContentVec) ----
 
13
  class HubertEncoder(nn.Module):
14
- def __init__(self, input_dim=1024, hidden_dim=768):
15
  super().__init__()
 
16
  self.conv1 = nn.Conv1d(input_dim, hidden_dim, 3, padding=1)
17
  self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
18
- self.lstm = nn.LSTM(hidden_dim, hidden_dim//2, 2, batch_first=True, bidirectional=True)
 
 
 
 
 
 
19
  self.proj = nn.Linear(hidden_dim, 256)
20
-
21
  def forward(self, x):
22
- x = x.transpose(1, 2) # (B, T, F) -> (B, F, T)
 
23
  x = torch.relu(self.conv1(x))
24
  x = torch.relu(self.conv2(x))
25
- x = x.transpose(1, 2) # Back to (B, T, F)
26
  out, _ = self.lstm(x)
27
- return self.proj(out) # 256-dim features
 
28
 
29
  class PitchEncoder(nn.Module):
30
  def __init__(self):
31
  super().__init__()
 
32
  self.f0_conv = nn.Sequential(
33
  nn.Conv1d(1, 64, 3, padding=1),
34
  nn.ReLU(),
35
  nn.Conv1d(64, 128, 3, padding=1),
36
- nn.ReLU()
37
  )
 
38
  self.pitch_proj = nn.Linear(128, 256)
39
-
40
  def forward(self, f0):
41
- f0 = f0.unsqueeze(1).transpose(1, 2) # (B, T) -> (B, 1, T)
42
- out = self.f0_conv(f0)
43
- out = out.mean(-1) # Global avg pool
44
- return self.pitch_proj(out)
 
 
45
 
46
  class RVCDecoder(nn.Module):
47
- def __init__(self, dim=256):
48
  super().__init__()
49
- self.content_lstm = nn.LSTM(dim, dim, 2, batch_first=True, bidirectional=True)
50
- self.pitch_lstm = nn.LSTM(dim, dim//2, 1, batch_first=True)
51
- self.fusion = nn.MultiheadAttention(dim*2, 8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  self.output_proj = nn.Sequential(
53
- nn.Linear(dim*2, dim),
54
  nn.ReLU(),
55
- nn.Linear(dim, 1024) # Mel output
56
  )
57
-
58
  def forward(self, content, pitch):
59
- content_out, _ = self.content_lstm(content)
60
- pitch_out, _ = self.pitch_lstm(pitch)
61
- pitch_out = pitch_out.repeat(1, content_out.size(1), 1)
62
-
63
- fused, _ = self.fusion(content_out, pitch_out, content_out)
64
- return self.output_proj(fused)
 
 
 
 
 
 
 
 
65
 
66
  class RVCv2(nn.Module):
67
  def __init__(self):
68
  super().__init__()
69
- self.hubert = HubertEncoder()
70
  self.pitch = PitchEncoder()
71
- self.decoder = RVCDecoder()
72
-
73
  def forward(self, mel, f0):
 
 
74
  content = self.hubert(mel)
75
  pitch_feat = self.pitch(f0)
76
  return self.decoder(content, pitch_feat)
77
 
78
- # ---- Advanced Audio Dataset ----
 
79
  class RVCv2Dataset(Dataset):
80
  def __init__(self, dataset_dir, sample_rate=40000, duration=10):
81
  self.files = list(Path(dataset_dir).glob("*.wav"))
 
 
 
 
82
  self.sample_rate = sample_rate
83
  self.duration = duration
84
  self.n_samples = int(sample_rate * duration)
85
-
 
 
 
 
 
 
 
86
  def __len__(self):
87
  return len(self.files)
88
-
89
  def __getitem__(self, idx):
90
  waveform, sr = torchaudio.load(self.files[idx])
91
-
 
 
 
 
92
  # Resample
93
  if sr != self.sample_rate:
94
  resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
95
  waveform = resampler(waveform)
96
-
97
- # Trim/pad
98
  if waveform.shape[1] > self.n_samples:
99
  waveform = waveform[:, :self.n_samples]
100
  else:
101
- waveform = torch.nn.functional.pad(waveform, (0, self.n_samples - waveform.shape[1]))
102
-
103
- # Mel spectrogram (target)
104
- mel_transform = torchaudio.transforms.MelSpectrogram(
105
- sample_rate=self.sample_rate, n_mels=128, n_fft=2048, hop_length=512
106
- )
107
- mel = mel_transform(waveform).squeeze(0)
108
  mel = torch.log(mel + 1e-9)
109
-
110
- # Dummy F0 (real impl needs crepe/dio)
111
- f0 = torch.ones(mel.shape[0]) * 200.0 # Placeholder
112
- f0 = torch.tensor(f0).float()
113
-
 
 
114
  return mel, f0, waveform
115
 
 
116
  def collate_fn(batch):
117
  mels, f0s, waves = zip(*batch)
118
- mels = pad_sequence(mels, batch_first=True, padding_value=0.0)
119
- f0s = pad_sequence(f0s.unsqueeze(1), batch_first=True, padding_value=0.0).squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  return mels, f0s, waves
121
 
 
122
  # ---- Training Loop ----
123
- def train_rvc_v2(model_name, dataset_dir, sample_rate=40000, epochs=200, batch_size=8, lr=2e-4):
 
 
 
 
 
 
 
124
  print(f"🚀 RVC v2 Training Started: {model_name}")
125
- print(f"📂 Dataset: {dataset_dir} ({len(os.listdir(dataset_dir))} files)")
126
-
127
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
128
  print(f"🛠️ Device: {device}")
129
-
130
- # Data
131
  dataset = RVCv2Dataset(dataset_dir, sample_rate)
132
- dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
133
-
134
- # Model
 
 
 
 
 
 
 
 
135
  model = RVCv2().to(device)
136
- optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
137
- scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
 
 
 
 
 
 
 
 
 
 
138
  criterion = nn.MSELoss()
139
-
140
  os.makedirs("weights", exist_ok=True)
141
-
142
- best_loss = float('inf')
 
143
  for epoch in range(epochs):
144
  model.train()
145
- total_loss = 0
146
-
147
  for batch_idx, (mel, f0, _) in enumerate(dataloader):
148
- mel, f0 = mel.to(device), f0.to(device)
149
-
 
150
  optimizer.zero_grad()
 
151
  output = model(mel, f0)
152
- loss = criterion(output, mel) # Reconstruction
 
153
  loss.backward()
154
-
155
- # Gradient clipping
156
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 
157
  optimizer.step()
158
-
159
  total_loss += loss.item()
160
-
161
  scheduler.step()
 
162
  avg_loss = total_loss / len(dataloader)
163
-
164
  if avg_loss < best_loss:
165
  best_loss = avg_loss
166
  torch.save(model.state_dict(), f"weights/{model_name}.pth")
167
-
168
  if epoch % 10 == 0:
169
- print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")
170
-
171
- print(f"✅ RVC v2 Training Complete! Best model: weights/{model_name}.pth")
 
 
 
 
 
 
172
 
173
  if __name__ == "__main__":
174
  parser = argparse.ArgumentParser(description="RVC v2 Training")
175
- parser.add_argument("--model_name", required=True, help="Model name (e.g., zeynep_rvc)")
176
- parser.add_argument("--dataset", required=True, help="Path to dataset folder")
 
177
  parser.add_argument("--sample_rate", type=int, default=40000)
178
  parser.add_argument("--epochs", type=int, default=200)
179
  parser.add_argument("--batch_size", type=int, default=8)
180
-
181
  args = parser.parse_args()
182
- train_rvc_v2(args.model_name, args.dataset, args.sample_rate, args.epochs, args.batch_size)
 
 
 
 
 
 
 
 
1
  import os
2
  import torch
3
  import torchaudio
 
4
  from torch.utils.data import Dataset, DataLoader
5
  import torch.nn as nn
6
  import torch.optim as optim
 
8
  import argparse
9
  from torch.nn.utils.rnn import pad_sequence
10
 
11
+
12
+ # ---- Simplified RVC-like Architecture ----
13
  class HubertEncoder(nn.Module):
14
+ def __init__(self, input_dim=128, hidden_dim=256):
15
  super().__init__()
16
+
17
  self.conv1 = nn.Conv1d(input_dim, hidden_dim, 3, padding=1)
18
  self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 3, padding=1)
19
+ self.lstm = nn.LSTM(
20
+ hidden_dim,
21
+ hidden_dim // 2,
22
+ num_layers=2,
23
+ batch_first=True,
24
+ bidirectional=True,
25
+ )
26
  self.proj = nn.Linear(hidden_dim, 256)
27
+
28
  def forward(self, x):
29
+ # x: (B, T, 128)
30
+ x = x.transpose(1, 2) # (B, 128, T)
31
  x = torch.relu(self.conv1(x))
32
  x = torch.relu(self.conv2(x))
33
+ x = x.transpose(1, 2) # (B, T, hidden)
34
  out, _ = self.lstm(x)
35
+ return self.proj(out) # (B, T, 256)
36
+
37
 
38
  class PitchEncoder(nn.Module):
39
  def __init__(self):
40
  super().__init__()
41
+
42
  self.f0_conv = nn.Sequential(
43
  nn.Conv1d(1, 64, 3, padding=1),
44
  nn.ReLU(),
45
  nn.Conv1d(64, 128, 3, padding=1),
46
+ nn.ReLU(),
47
  )
48
+
49
  self.pitch_proj = nn.Linear(128, 256)
50
+
51
  def forward(self, f0):
52
+ # f0: (B, T)
53
+ x = f0.unsqueeze(1) # (B, 1, T)
54
+ x = self.f0_conv(x) # (B, 128, T)
55
+ x = x.transpose(1, 2) # (B, T, 128)
56
+ return self.pitch_proj(x) # (B, T, 256)
57
+
58
 
59
  class RVCDecoder(nn.Module):
60
+ def __init__(self, dim=256, mel_dim=128):
61
  super().__init__()
62
+
63
+ self.content_lstm = nn.LSTM(
64
+ dim,
65
+ dim,
66
+ num_layers=2,
67
+ batch_first=True,
68
+ bidirectional=True,
69
+ )
70
+
71
+ self.pitch_proj = nn.Linear(dim, dim * 2)
72
+
73
+ self.fusion = nn.MultiheadAttention(
74
+ embed_dim=dim * 2,
75
+ num_heads=8,
76
+ batch_first=True,
77
+ )
78
+
79
  self.output_proj = nn.Sequential(
80
+ nn.Linear(dim * 2, dim),
81
  nn.ReLU(),
82
+ nn.Linear(dim, mel_dim),
83
  )
84
+
85
  def forward(self, content, pitch):
86
+ # content: (B, T, 256)
87
+ # pitch: (B, T, 256)
88
+
89
+ content_out, _ = self.content_lstm(content) # (B, T, 512)
90
+ pitch_out = self.pitch_proj(pitch) # (B, T, 512)
91
+
92
+ fused, _ = self.fusion(
93
+ query=content_out,
94
+ key=pitch_out,
95
+ value=content_out,
96
+ )
97
+
98
+ return self.output_proj(fused) # (B, T, 128)
99
+
100
 
101
  class RVCv2(nn.Module):
102
  def __init__(self):
103
  super().__init__()
104
+ self.hubert = HubertEncoder(input_dim=128)
105
  self.pitch = PitchEncoder()
106
+ self.decoder = RVCDecoder(dim=256, mel_dim=128)
107
+
108
  def forward(self, mel, f0):
109
+ # mel: (B, T, 128)
110
+ # f0: (B, T)
111
  content = self.hubert(mel)
112
  pitch_feat = self.pitch(f0)
113
  return self.decoder(content, pitch_feat)
114
 
115
+
116
+ # ---- Dataset ----
117
  class RVCv2Dataset(Dataset):
118
  def __init__(self, dataset_dir, sample_rate=40000, duration=10):
119
  self.files = list(Path(dataset_dir).glob("*.wav"))
120
+
121
+ if len(self.files) == 0:
122
+ raise ValueError(f"No .wav files found in {dataset_dir}")
123
+
124
  self.sample_rate = sample_rate
125
  self.duration = duration
126
  self.n_samples = int(sample_rate * duration)
127
+
128
+ self.mel_transform = torchaudio.transforms.MelSpectrogram(
129
+ sample_rate=self.sample_rate,
130
+ n_mels=128,
131
+ n_fft=2048,
132
+ hop_length=512,
133
+ )
134
+
135
  def __len__(self):
136
  return len(self.files)
137
+
138
  def __getitem__(self, idx):
139
  waveform, sr = torchaudio.load(self.files[idx])
140
+
141
+ # Convert stereo/multi-channel to mono
142
+ if waveform.shape[0] > 1:
143
+ waveform = waveform.mean(dim=0, keepdim=True)
144
+
145
  # Resample
146
  if sr != self.sample_rate:
147
  resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
148
  waveform = resampler(waveform)
149
+
150
+ # Trim/pad audio
151
  if waveform.shape[1] > self.n_samples:
152
  waveform = waveform[:, :self.n_samples]
153
  else:
154
+ pad_amount = self.n_samples - waveform.shape[1]
155
+ waveform = torch.nn.functional.pad(waveform, (0, pad_amount))
156
+
157
+ # Mel spectrogram: (1, 128, T) -> (128, T)
158
+ mel = self.mel_transform(waveform).squeeze(0)
 
 
159
  mel = torch.log(mel + 1e-9)
160
+
161
+ # Convert to (T, 128)
162
+ mel = mel.transpose(0, 1)
163
+
164
+ # Dummy F0 placeholder, one value per time frame
165
+ f0 = torch.ones(mel.shape[0], dtype=torch.float32) * 200.0
166
+
167
  return mel, f0, waveform
168
 
169
+
170
  def collate_fn(batch):
171
  mels, f0s, waves = zip(*batch)
172
+
173
+ # mels are list of tensors shaped (T, 128)
174
+ mels = pad_sequence(
175
+ mels,
176
+ batch_first=True,
177
+ padding_value=0.0,
178
+ )
179
+
180
+ # f0s are list of tensors shaped (T,)
181
+ f0s = pad_sequence(
182
+ f0s,
183
+ batch_first=True,
184
+ padding_value=0.0,
185
+ )
186
+
187
  return mels, f0s, waves
188
 
189
+
190
  # ---- Training Loop ----
191
+ def train_rvc_v2(
192
+ model_name,
193
+ dataset_dir,
194
+ sample_rate=40000,
195
+ epochs=200,
196
+ batch_size=8,
197
+ lr=2e-4,
198
+ ):
199
  print(f"🚀 RVC v2 Training Started: {model_name}")
200
+ print(f"📂 Dataset: {dataset_dir}")
201
+
202
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
203
  print(f"🛠️ Device: {device}")
204
+
 
205
  dataset = RVCv2Dataset(dataset_dir, sample_rate)
206
+
207
+ print(f"🎧 Files found: {len(dataset)}")
208
+
209
+ dataloader = DataLoader(
210
+ dataset,
211
+ batch_size=batch_size,
212
+ shuffle=True,
213
+ collate_fn=collate_fn,
214
+ num_workers=0,
215
+ )
216
+
217
  model = RVCv2().to(device)
218
+
219
+ optimizer = optim.AdamW(
220
+ model.parameters(),
221
+ lr=lr,
222
+ weight_decay=1e-5,
223
+ )
224
+
225
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(
226
+ optimizer,
227
+ T_max=epochs,
228
+ )
229
+
230
  criterion = nn.MSELoss()
231
+
232
  os.makedirs("weights", exist_ok=True)
233
+
234
+ best_loss = float("inf")
235
+
236
  for epoch in range(epochs):
237
  model.train()
238
+ total_loss = 0.0
239
+
240
  for batch_idx, (mel, f0, _) in enumerate(dataloader):
241
+ mel = mel.to(device)
242
+ f0 = f0.to(device)
243
+
244
  optimizer.zero_grad()
245
+
246
  output = model(mel, f0)
247
+
248
+ loss = criterion(output, mel)
249
  loss.backward()
250
+
 
251
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
252
+
253
  optimizer.step()
254
+
255
  total_loss += loss.item()
256
+
257
  scheduler.step()
258
+
259
  avg_loss = total_loss / len(dataloader)
260
+
261
  if avg_loss < best_loss:
262
  best_loss = avg_loss
263
  torch.save(model.state_dict(), f"weights/{model_name}.pth")
264
+
265
  if epoch % 10 == 0:
266
+ print(
267
+ f"Epoch {epoch}/{epochs} | "
268
+ f"Loss: {avg_loss:.4f} | "
269
+ f"Best: {best_loss:.4f} | "
270
+ f"LR: {scheduler.get_last_lr()[0]:.2e}"
271
+ )
272
+
273
+ print(f"✅ Training Complete! Best model: weights/{model_name}.pth")
274
+
275
 
276
  if __name__ == "__main__":
277
  parser = argparse.ArgumentParser(description="RVC v2 Training")
278
+
279
+ parser.add_argument("--model_name", required=True)
280
+ parser.add_argument("--dataset", required=True)
281
  parser.add_argument("--sample_rate", type=int, default=40000)
282
  parser.add_argument("--epochs", type=int, default=200)
283
  parser.add_argument("--batch_size", type=int, default=8)
284
+
285
  args = parser.parse_args()
286
+
287
+ train_rvc_v2(
288
+ model_name=args.model_name,
289
+ dataset_dir=args.dataset,
290
+ sample_rate=args.sample_rate,
291
+ epochs=args.epochs,
292
+ batch_size=args.batch_size,
293
+ )