drixo committed on
Commit
f853855
·
verified ·
1 Parent(s): 90a21e7

Delete train.py

Browse files
Files changed (1) hide show
  1. train.py +0 -69
train.py DELETED
@@ -1,69 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torchaudio
4
- from torch.utils.data import DataLoader
5
- from datasets import load_dataset
6
- from model.model import RealtimeTTS
7
- from model.config import TTSConfig
8
- from model.tokenizer import TTSTokenizer
9
-
10
- device = "cuda" if torch.cuda.is_available() else "cpu"
11
- config = TTSConfig()
12
-
13
- # Load tokenizer
14
- tokenizer = TTSTokenizer("tts_tokenizer.model")
15
-
16
- # Load dataset
17
- dataset = load_dataset("csv", data_files={"train": "train.csv"})["train"]
18
-
19
- mel_transform = torchaudio.transforms.MelSpectrogram(
20
- sample_rate=22050,
21
- n_mels=config.mel_bins
22
- )
23
-
24
- def preprocess(example):
25
- audio, sr = torchaudio.load(example["audio_path"])
26
- mel = mel_transform(audio).transpose(1, 2)
27
- tokens = tokenizer.encode(example["text"])
28
-
29
- return {
30
- "tokens": torch.tensor(tokens),
31
- "mel": mel.squeeze(0)
32
- }
33
-
34
- dataset = dataset.map(preprocess)
35
-
36
- def collate_fn(batch):
37
- tokens = [item["tokens"] for item in batch]
38
- mels = [item["mel"] for item in batch]
39
-
40
- tokens = nn.utils.rnn.pad_sequence(tokens, batch_first=True)
41
- mels = nn.utils.rnn.pad_sequence(mels, batch_first=True)
42
-
43
- return tokens, mels
44
-
45
- dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
46
-
47
- model = RealtimeTTS(config).to(device)
48
- optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
49
- loss_fn = nn.MSELoss()
50
-
51
- # Training loop
52
- for epoch in range(10):
53
- model.train()
54
- for tokens, mels in dataloader:
55
- tokens = tokens.to(device)
56
- mels = mels.to(device)
57
-
58
- mel_input = torch.zeros_like(mels)
59
- output = model(tokens, mel_input)
60
-
61
- loss = loss_fn(output, mels)
62
-
63
- optimizer.zero_grad()
64
- loss.backward()
65
- optimizer.step()
66
-
67
- print(f"Epoch {epoch} Loss: {loss.item()}")
68
-
69
- torch.save(model.state_dict(), "model.pt")