# train.py — trains a simple 1-D convolutional audio autoencoder
# (placeholder for a real RVC voice-conversion model).
import os
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import librosa
import soundfile as sf
import numpy as np
# -----------------------------
# Dataset
# -----------------------------
class AudioDataset(Dataset):
    """Loads WAV files, resamples to a target rate, and pads/trims to a fixed length.

    Each item is a mono float32 tensor of shape (1, samples).
    """

    def __init__(self, dataset_path, sample_rate=44100, duration_sec=2.0):
        """
        Args:
            dataset_path: root directory searched recursively for *.wav files.
            sample_rate: target sample rate in Hz; files are resampled to this.
            duration_sec: fixed clip length in seconds (zero-pad or trim).

        Raises:
            ValueError: if no WAV files are found under dataset_path.
        """
        self.sample_rate = sample_rate
        self.duration = int(sample_rate * duration_sec)
        self.files = sorted(
            glob.glob(os.path.join(dataset_path, "**/*.wav"), recursive=True)
        )
        if not self.files:
            raise ValueError("No WAV files found in dataset.")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        audio, sr = sf.read(path)
        # FIX: sf.read returns shape (samples, channels) for multi-channel
        # files; the pad/trim logic and the Conv1d model both assume a 1-D
        # mono signal, so downmix by averaging channels first.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)
        # Resample if the file's native rate differs from the target rate.
        if sr != self.sample_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)
        # Zero-pad short clips / trim long clips to the fixed sample count.
        if len(audio) < self.duration:
            audio = np.pad(audio, (0, self.duration - len(audio)))
        else:
            audio = audio[: self.duration]
        audio = torch.tensor(audio, dtype=torch.float32)
        return audio.unsqueeze(0)  # (1, samples)
# -----------------------------
# Simple placeholder model
# (Replace with real RVC model)
# -----------------------------
class SimpleVoiceModel(nn.Module):
    """Tiny 1-D convolutional autoencoder used as a stand-in voice model.

    Maps (batch, 1, samples) -> (batch, 1, samples); the length is preserved
    because every conv uses kernel size 3 with padding 1.
    """

    def __init__(self):
        super().__init__()
        # Encoder: widen channels 1 -> 32 -> 64, ReLU after each conv.
        enc_layers = [
            nn.Conv1d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, 3, padding=1),
            nn.ReLU(),
        ]
        # Decoder: mirror back 64 -> 32 -> 1; no activation on the output.
        dec_layers = [
            nn.Conv1d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 1, 3, padding=1),
        ]
        self.encoder = nn.Sequential(*enc_layers)
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Encode then decode; returns a tensor with the same shape as x."""
        return self.decoder(self.encoder(x))
# -----------------------------
# Training Loop
# -----------------------------
def train(
    dataset_path="dataset",
    batch_size=4,
    epochs=5,
    lr=1e-4,
    save_every=1,
    device="cuda" if torch.cuda.is_available() else "cpu",
):
    """Train the placeholder autoencoder on WAV clips and save checkpoints.

    Args:
        dataset_path: directory searched recursively for *.wav files.
        batch_size: samples per optimizer step.
        epochs: number of full passes over the dataset.
        lr: Adam learning rate.
        save_every: write a checkpoint every N epochs.
        device: torch device string; defaults to CUDA when available.

    Raises:
        ValueError: propagated from AudioDataset when no WAV files are found.
    """
    print("Loading dataset...")
    ds = AudioDataset(dataset_path)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True)

    print("Initializing model...")
    model = SimpleVoiceModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.L1Loss()

    os.makedirs("checkpoints", exist_ok=True)

    print("Training...")
    # FIX: explicitly enter training mode — a no-op for this placeholder,
    # but required once a real model with dropout/batch-norm is swapped in
    # (the class comment says to replace it with a real RVC model).
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for batch in dl:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            loss = criterion(out, batch)  # Autoencoder-style reconstruction loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(dl)
        print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.6f}")

        if epoch % save_every == 0:
            # FIX: os.path.join keeps the checkpoint path portable.
            ckpt_path = os.path.join("checkpoints", f"model_epoch_{epoch}.pt")
            torch.save(model.state_dict(), ckpt_path)
            print(f"Saved checkpoint: {ckpt_path}")

    print("Training complete!")
# -----------------------------
# Entry point (called by app.py)
# -----------------------------
if __name__ == "__main__":
    # Run a training session with the default dataset location.
    run_config = {
        "dataset_path": "dataset",
        "batch_size": 4,
        "epochs": 10,
        "lr": 1e-4,
        "save_every": 1,
    }
    train(**run_config)