| | import os |
| | import glob |
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import Dataset, DataLoader |
| | import librosa |
| | import soundfile as sf |
| | import numpy as np |
| |
|
| | |
| | |
| | |
class AudioDataset(Dataset):
    """Dataset of fixed-length mono waveforms loaded from a tree of WAV files.

    Each item is a float32 tensor of shape (1, samples) — channel-first so it
    can feed a Conv1d directly. Files are resampled to ``sample_rate`` and
    padded/truncated to exactly ``duration_sec`` seconds.
    """

    def __init__(self, dataset_path, sample_rate=44100, duration_sec=2.0):
        """
        Args:
            dataset_path: Root directory searched recursively for ``*.wav``.
            sample_rate: Target sample rate in Hz.
            duration_sec: Target clip length in seconds.

        Raises:
            ValueError: If no WAV files are found under ``dataset_path``.
        """
        self.sample_rate = sample_rate
        # Target length in samples; every returned clip has exactly this length.
        self.duration = int(sample_rate * duration_sec)

        # Sorted for deterministic ordering across runs/platforms.
        self.files = sorted(glob.glob(os.path.join(dataset_path, "**/*.wav"), recursive=True))
        if len(self.files) == 0:
            raise ValueError("No WAV files found in dataset.")

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        audio, sr = sf.read(path)

        # Bug fix: soundfile returns shape (frames, channels) for multi-channel
        # files. The original code resampled that 2-D array along the wrong
        # axis (librosa treats the LAST axis as time), padded both axes with
        # np.pad, and returned a (1, frames, channels) tensor. Downmix to mono
        # first so everything below operates on a 1-D signal.
        if audio.ndim > 1:
            audio = audio.mean(axis=1)

        if sr != self.sample_rate:
            audio = librosa.resample(audio, orig_sr=sr, target_sr=self.sample_rate)

        # Zero-pad short clips; truncate long ones to the fixed length.
        if len(audio) < self.duration:
            audio = np.pad(audio, (0, self.duration - len(audio)))
        else:
            audio = audio[:self.duration]

        audio = torch.tensor(audio, dtype=torch.float32)
        return audio.unsqueeze(0)  # (1, samples)
| |
|
| |
|
| | |
| | |
| | |
| | |
class SimpleVoiceModel(nn.Module):
    """1-D convolutional autoencoder over raw waveforms.

    Input and output are shaped (batch, 1, samples); every convolution uses
    kernel size 3 with padding 1, so the time dimension is preserved
    end-to-end.
    """

    def __init__(self):
        super().__init__()
        # Encoder widens channels: 1 -> 32 -> 64.
        encoder_layers = [
            nn.Conv1d(1, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, 3, padding=1),
            nn.ReLU(),
        ]
        # Decoder mirrors it back down: 64 -> 32 -> 1 (no final activation).
        decoder_layers = [
            nn.Conv1d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 1, 3, padding=1),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode then reconstruct; returns a tensor shaped like ``x``."""
        return self.decoder(self.encoder(x))
| |
|
| |
|
| | |
| | |
| | |
def train(
    dataset_path="dataset",
    batch_size=4,
    epochs=5,
    lr=1e-4,
    save_every=1,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    """Train SimpleVoiceModel as an autoencoder on WAV clips.

    Reconstructs each batch from itself under an L1 loss and writes periodic
    checkpoints to ``checkpoints/``.

    Args:
        dataset_path: Directory searched recursively for ``*.wav`` files.
        batch_size: DataLoader batch size.
        epochs: Number of full passes over the dataset.
        lr: Adam learning rate.
        save_every: Save a checkpoint every N epochs.
        device: Torch device string; defaults to CUDA when available.

    Raises:
        ValueError: Propagated from AudioDataset when no WAV files are found.
    """
    print("Loading dataset...")
    ds = AudioDataset(dataset_path)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True)

    print("Initializing model...")
    model = SimpleVoiceModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.L1Loss()

    os.makedirs("checkpoints", exist_ok=True)

    print("Training...")
    # Fix: explicitly put the model in training mode. Harmless today, but it
    # prevents a silent bug the moment dropout/batchnorm layers are added.
    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0

        for batch in dl:
            batch = batch.to(device)

            optimizer.zero_grad()
            out = model(batch)
            # Autoencoder objective: reconstruct the input batch itself.
            loss = criterion(out, batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # len(dl) >= 1 because AudioDataset rejects empty datasets.
        avg_loss = total_loss / len(dl)
        print(f"Epoch {epoch}/{epochs} | Loss: {avg_loss:.6f}")

        if epoch % save_every == 0:
            # os.path.join instead of a hard-coded "/" for portability.
            ckpt_path = os.path.join("checkpoints", f"model_epoch_{epoch}.pt")
            torch.save(model.state_dict(), ckpt_path)
            print(f"Saved checkpoint: {ckpt_path}")

    print("Training complete!")
| |
|
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Hyperparameters for the default training run, gathered in one place.
    config = {
        "dataset_path": "dataset",
        "batch_size": 4,
        "epochs": 10,
        "lr": 1e-4,
        "save_every": 1,
    }
    train(**config)
| |
|