seemanthraju commited on
Commit
7be9079
·
0 Parent(s):

first commit

Browse files
Files changed (37) hide show
  1. .gitattributes +3 -0
  2. .gitignore +83 -0
  3. README.md +198 -0
  4. chiluka/__init__.py +9 -0
  5. chiluka/configs/config_ft.yml +116 -0
  6. chiluka/inference.py +368 -0
  7. chiluka/models/__init__.py +21 -0
  8. chiluka/models/core.py +731 -0
  9. chiluka/models/diffusion/__init__.py +22 -0
  10. chiluka/models/diffusion/diffusion.py +72 -0
  11. chiluka/models/diffusion/modules.py +367 -0
  12. chiluka/models/diffusion/sampler.py +176 -0
  13. chiluka/models/diffusion/utils.py +40 -0
  14. chiluka/models/hifigan.py +266 -0
  15. chiluka/pretrained/ASR/__init__.py +1 -0
  16. chiluka/pretrained/ASR/__pycache__/__init__.cpython-310.pyc +0 -0
  17. chiluka/pretrained/ASR/__pycache__/layers.cpython-310.pyc +0 -0
  18. chiluka/pretrained/ASR/__pycache__/models.cpython-310.pyc +0 -0
  19. chiluka/pretrained/ASR/config.yml +29 -0
  20. chiluka/pretrained/ASR/epoch_00080.pth +3 -0
  21. chiluka/pretrained/ASR/layers.py +354 -0
  22. chiluka/pretrained/ASR/models.py +186 -0
  23. chiluka/pretrained/JDC/__init__.py +1 -0
  24. chiluka/pretrained/JDC/__pycache__/__init__.cpython-310.pyc +0 -0
  25. chiluka/pretrained/JDC/__pycache__/model.cpython-310.pyc +0 -0
  26. chiluka/pretrained/JDC/bst.t7 +3 -0
  27. chiluka/pretrained/JDC/model.py +190 -0
  28. chiluka/pretrained/PLBERT/__pycache__/util.cpython-310.pyc +0 -0
  29. chiluka/pretrained/PLBERT/config.yml +30 -0
  30. chiluka/pretrained/PLBERT/step_1000000.t7 +3 -0
  31. chiluka/pretrained/PLBERT/util.py +42 -0
  32. chiluka/text_utils.py +24 -0
  33. chiluka/utils.py +21 -0
  34. examples/basic_synthesis.py +51 -0
  35. examples/telugu_synthesis.py +53 -0
  36. pyproject.toml +64 -0
  37. setup.py +60 -0
.gitattributes ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Git LFS tracking for large model files
2
+ *.pth filter=lfs diff=lfs merge=lfs -text
3
+ *.t7 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+
27
+ # PyInstaller
28
+ *.manifest
29
+ *.spec
30
+
31
+ # Installer logs
32
+ pip-log.txt
33
+ pip-delete-this-directory.txt
34
+
35
+ # Unit test / coverage reports
36
+ htmlcov/
37
+ .tox/
38
+ .nox/
39
+ .coverage
40
+ .coverage.*
41
+ .cache
42
+ nosetests.xml
43
+ coverage.xml
44
+ *.cover
45
+ *.py,cover
46
+ .hypothesis/
47
+ .pytest_cache/
48
+
49
+ # Translations
50
+ *.mo
51
+ *.pot
52
+
53
+ # Environments
54
+ .env
55
+ .venv
56
+ env/
57
+ venv/
58
+ ENV/
59
+ env.bak/
60
+ venv.bak/
61
+ tests/
62
+ # IDE
63
+ .idea/
64
+ .vscode/
65
+ *.swp
66
+ *.swo
67
+
68
+ # Jupyter Notebook
69
+ .ipynb_checkpoints
70
+
71
+ # OS
72
+ .DS_Store
73
+ Thumbs.db
74
+
75
+ # Test outputs
76
+ test_outputs/
77
+ *.wav
78
+ !chiluka/pretrained/**
79
+
80
+ # Note: Large model files are tracked with Git LFS
81
+ # If not using Git LFS, uncomment these lines:
82
+ # *.pth
83
+ # *.t7
README.md ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Chiluka 🦜
2
+
3
+ **Chiluka** (చిలుక - Telugu for "parrot") is a self-contained TTS (Text-to-Speech) inference package based on StyleTTS2.
4
+
5
+ ## Features
6
+
7
+ - 🚀 Simple, clean API for TTS synthesis
8
+ - 📦 **Fully self-contained** - all models bundled in the package
9
+ - 🎙️ Style transfer from reference audio
10
+ - 🌍 Multi-language support via phonemizer
11
+ - 🔧 No external dependencies on other repos
12
+
13
+ ## Installation
14
+
15
+ ### From Source (Recommended)
16
+
17
+ ```bash
18
+ git clone https://github.com/yourusername/chiluka.git
19
+ cd chiluka
20
+ pip install -e .
21
+ ```
22
+
23
+ **Note:** This repo uses Git LFS for large model files. Make sure to install Git LFS first:
24
+
25
+ ```bash
26
+ # Ubuntu/Debian
27
+ sudo apt-get install git-lfs
28
+ git lfs install
29
+
30
+ # macOS
31
+ brew install git-lfs
32
+ git lfs install
33
+
34
+ # Then clone
35
+ git lfs clone https://github.com/yourusername/chiluka.git
36
+ ```
37
+
38
+ ### Install espeak-ng (Required for phonemization)
39
+
40
+ **Ubuntu/Debian:**
41
+ ```bash
42
+ sudo apt-get install espeak-ng
43
+ ```
44
+
45
+ **macOS:**
46
+ ```bash
47
+ brew install espeak-ng
48
+ ```
49
+
50
+ ## Quick Start
51
+
52
+ ```python
53
+ from chiluka import Chiluka
54
+
55
+ # Initialize - uses bundled models automatically!
56
+ tts = Chiluka()
57
+
58
+ # Synthesize speech
59
+ wav = tts.synthesize(
60
+ text="Hello, this is Chiluka speaking!",
61
+ reference_audio="path/to/reference.wav",
62
+ language="en"
63
+ )
64
+
65
+ # Save to file
66
+ tts.save_wav(wav, "output.wav")
67
+ ```
68
+
69
+ ### Telugu Example
70
+
71
+ ```python
72
+ from chiluka import Chiluka
73
+
74
+ tts = Chiluka()
75
+
76
+ wav = tts.synthesize(
77
+ text="నమస్కారం, నేను చిలుక మాట్లాడుతున్నాను",
78
+ reference_audio="path/to/telugu_reference.wav",
79
+ language="te" # Telugu
80
+ )
81
+
82
+ tts.save_wav(wav, "telugu_output.wav")
83
+ ```
84
+
85
+ ## Package Structure
86
+
87
+ ```
88
+ chiluka/
89
+ ├── chiluka/
90
+ │ ├── __init__.py
91
+ │ ├── inference.py # Main Chiluka API
92
+ │ ├── text_utils.py
93
+ │ ├── utils.py
94
+ │ ├── configs/
95
+ │ │ └── config_ft.yml # Model configuration
96
+ │ ├── checkpoints/
97
+ │ │ └── *.pth # Trained model checkpoint
98
+ │ ├── pretrained/
99
+ │ │ ├── ASR/ # Text aligner model
100
+ │ │ ├── JDC/ # Pitch extractor model
101
+ │ │ └── PLBERT/ # PL-BERT model
102
+ │ └── models/
103
+ │ ├── core.py
104
+ │ ├── hifigan.py
105
+ │ └── diffusion/
106
+ ├── examples/
107
+ │ ├── basic_synthesis.py
108
+ │ └── telugu_synthesis.py
109
+ ├── setup.py
110
+ ├── pyproject.toml
111
+ └── README.md
112
+ ```
113
+
114
+ ## API Reference
115
+
116
+ ### Chiluka Class
117
+
118
+ ```python
119
+ tts = Chiluka(
120
+ config_path=None, # Optional: custom config file
121
+ checkpoint_path=None, # Optional: custom checkpoint
122
+ pretrained_dir=None, # Optional: custom pretrained models
123
+ device=None # Optional: 'cuda' or 'cpu'
124
+ )
125
+ ```
126
+
127
+ ### synthesize()
128
+
129
+ ```python
130
+ wav = tts.synthesize(
131
+ text="Hello world", # Text to synthesize
132
+ reference_audio="ref.wav", # Reference audio for style
133
+ language="en", # Language code ('en', 'te', 'hi', etc.)
134
+ alpha=0.3, # Acoustic style mixing (0-1)
135
+ beta=0.7, # Prosodic style mixing (0-1)
136
+ diffusion_steps=5, # Diffusion sampling steps
137
+ embedding_scale=1.0, # Classifier-free guidance scale
138
+ sr=24000 # Sample rate
139
+ )
140
+ ```
141
+
142
+ ### Other Methods
143
+
144
+ ```python
145
+ # Save audio to file
146
+ tts.save_wav(wav, "output.wav", sr=24000)
147
+
148
+ # Play audio (requires pyaudio)
149
+ tts.play(wav, sr=24000)
150
+
151
+ # Get style embedding from audio
152
+ style = tts.compute_style("reference.wav", sr=24000)
153
+ ```
154
+
155
+ ## Synthesis Parameters
156
+
157
+ | Parameter | Default | Description |
158
+ |-----------|---------|-------------|
159
+ | `alpha` | 0.3 | Acoustic style mixing (0=reference only, 1=predicted only) |
160
+ | `beta` | 0.7 | Prosodic style mixing (0=reference only, 1=predicted only) |
161
+ | `diffusion_steps` | 5 | Number of diffusion sampling steps (more = better quality, slower) |
162
+ | `embedding_scale` | 1.0 | Classifier-free guidance scale |
163
+
164
+ ## Supported Languages
165
+
166
+ Uses [phonemizer](https://github.com/bootphon/phonemizer) with espeak-ng. Common languages:
167
+
168
+ | Language | Code |
169
+ |----------|------|
170
+ | English (US) | `en-us` |
171
+ | English (UK) | `en-gb` |
172
+ | Telugu | `te` |
173
+ | Hindi | `hi` |
174
+ | Tamil | `ta` |
175
+ | Kannada | `kn` |
176
+
177
+ See espeak-ng documentation for full list.
178
+
179
+ ## Requirements
180
+
181
+ - Python >= 3.8
182
+ - PyTorch >= 1.13.0
183
+ - CUDA (recommended for faster inference)
184
+ - espeak-ng
185
+
186
+ ## Training Your Own Model
187
+
188
+ This package is for **inference only**. To train your own model, use the original [StyleTTS2](https://github.com/yl4579/StyleTTS2) repository.
189
+
190
+ After training, copy your checkpoint to `chiluka/checkpoints/` and update the config if needed.
191
+
192
+ ## Credits
193
+
194
+ Based on [StyleTTS2](https://github.com/yl4579/StyleTTS2) by Yinghao Aaron Li et al.
195
+
196
+ ## License
197
+
198
+ MIT License
chiluka/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chiluka - A lightweight TTS inference package based on StyleTTS2
3
+ """
4
+
5
+ __version__ = "0.1.0"
6
+
7
+ from .inference import Chiluka
8
+
9
+ __all__ = ["Chiluka"]
chiluka/configs/config_ft.yml ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Models/tm_tel_ft_24k"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 2
4
+ log_interval: 10
5
+ device: "cuda"
6
+
7
+ epochs_1st: 30
8
+ epochs_2nd: 20
9
+
10
+ batch_size: 2 # Keep at 2 with filtering
11
+ max_len: 200 # This is fine - refers to audio frames, not phonemes
12
+
13
+ pretrained_model: "/home/purview/Documents/TextToSpeech_Backup/StyleTTS2/Models/LibriTTS/epochs_2nd_00020.pth"
14
+
15
+ second_stage_load_pretrained: true
16
+ load_only_params: true
17
+
18
+ F0_path: "Utils/JDC/bst.t7"
19
+ ASR_config: "Utils/ASR/config.yml"
20
+ ASR_path: "Utils/ASR/epoch_00080.pth"
21
+ PLBERT_dir: "Utils/PLBERT/"
22
+
23
+ data_params:
24
+ train_data: "Data_custom/train_list.txt"
25
+ val_data: "Data_custom/val_list.txt"
26
+ root_path: "/home/purview/Documents/TextToSpeech_Backup/Processed_Dataset_24k/wavs"
27
+ OOD_data: "Data_custom/OOD_texts.txt"
28
+ min_length: 50 # <<<< This is in phonemes - keep it low
29
+
30
+ # The preprocessing and model parameters below must match the pretrained checkpoint.
31
+
32
+ preprocess_params:
33
+ sr: 24000
34
+ spect_params:
35
+ n_fft: 2048
36
+ win_length: 1200
37
+ hop_length: 300
38
+
39
+ model_params:
40
+ # match the LibriTTS checkpoint setting (it was trained multispeaker:true)
41
+ # You can still finetune with only speaker_id=0 in your train_list.txt
42
+ multispeaker: true
43
+
44
+ dim_in: 64
45
+ hidden_dim: 512
46
+ max_conv_dim: 512
47
+ n_layer: 3
48
+ n_mels: 80
49
+ n_token: 178
50
+ max_dur: 50
51
+ style_dim: 128
52
+ dropout: 0.2
53
+
54
+ # MUST MATCH LibriTTS CHECKPOINT (this is your main fix)
55
+ decoder:
56
+ type: "hifigan"
57
+ resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
58
+ resblock_kernel_sizes: [3, 7, 11]
59
+ upsample_initial_channel: 512
60
+ upsample_rates: [10, 5, 3, 2]
61
+ upsample_kernel_sizes: [20, 10, 6, 4]
62
+
63
+ slm:
64
+ model: "microsoft/wavlm-base-plus"
65
+ sr: 16000
66
+ hidden: 768
67
+ nlayers: 13
68
+ initial_channel: 64
69
+
70
+ diffusion:
71
+ embedding_mask_proba: 0.1
72
+ transformer:
73
+ num_layers: 3
74
+ num_heads: 8
75
+ head_features: 64
76
+ multiplier: 2
77
+ dist:
78
+ sigma_data: 0.19926648961191362
79
+ estimate_sigma_data: true
80
+ mean: -3.0
81
+ std: 1.0
82
+
83
+ loss_params:
84
+ lambda_mel: 5.0
85
+ lambda_gen: 1.0
86
+ lambda_slm: 1.0
87
+
88
+ lambda_mono: 1.0
89
+ lambda_s2s: 1.0
90
+ TMA_epoch: 4
91
+
92
+ lambda_F0: 1.0
93
+ lambda_norm: 1.0
94
+ lambda_dur: 1.0
95
+ lambda_ce: 20.0
96
+ lambda_sty: 1.0
97
+ lambda_diff: 1.0
98
+
99
+ # For a safe first run, delay diffusion + joint/SLM-adv.
100
+ # After it runs, you can set these back to 0 like LibriTTS.
101
+ diff_epoch: 999
102
+ joint_epoch: 999
103
+
104
+ optimizer_params:
105
+ lr: 0.0001
106
+ bert_lr: 0.00001
107
+ ft_lr: 0.00001
108
+
109
+ slmadv_params:
110
+ min_len: 400
111
+ max_len: 500
112
+ batch_percentage: 0.5
113
+ iter: 20
114
+ thresh: 5
115
+ scale: 0.01
116
+ sig: 1.5
chiluka/inference.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Chiluka - Main inference API for TTS synthesis.
3
+
4
+ Example usage:
5
+ from chiluka import Chiluka
6
+
7
+ # Simple usage (uses bundled models)
8
+ tts = Chiluka()
9
+
10
+ # Generate speech
11
+ wav = tts.synthesize(
12
+ text="Hello, world!",
13
+ reference_audio="path/to/reference.wav",
14
+ language="en"
15
+ )
16
+
17
+ # Save to file
18
+ tts.save_wav(wav, "output.wav")
19
+ """
20
+
21
+ import os
22
+ import yaml
23
+ import torch
24
+ import torchaudio
25
+ import librosa
26
+ import numpy as np
27
+ from pathlib import Path
28
+ from typing import Optional, Union
29
+
30
+ from nltk.tokenize import word_tokenize
31
+
32
+ from .models import build_model, load_ASR_models, load_F0_models, load_plbert
33
+ from .models.diffusion import DiffusionSampler, ADPM2Sampler, KarrasSchedule
34
+ from .text_utils import TextCleaner
35
+ from .utils import recursive_munch, length_to_mask
36
+
37
+
38
+ # Get package directory
39
+ PACKAGE_DIR = Path(__file__).parent.absolute()
40
+ DEFAULT_PRETRAINED_DIR = PACKAGE_DIR / "pretrained"
41
+ DEFAULT_CONFIG_PATH = PACKAGE_DIR / "configs" / "config_ft.yml"
42
+ DEFAULT_CHECKPOINT_DIR = PACKAGE_DIR / "checkpoints"
43
+
44
+
45
def get_default_checkpoint():
    """Return the path of the first bundled ``*.pth`` checkpoint, or ``None``.

    Scans ``DEFAULT_CHECKPOINT_DIR`` (the package's ``checkpoints/`` folder);
    the "first" match follows the directory's glob order.
    """
    if not DEFAULT_CHECKPOINT_DIR.exists():
        return None
    for candidate in DEFAULT_CHECKPOINT_DIR.glob("*.pth"):
        # Take the first hit — same choice as list(...)[0] on the glob.
        return str(candidate)
    return None
52
+
53
+
54
class Chiluka:
    """
    Chiluka TTS - Text-to-Speech synthesis using StyleTTS2.

    Args:
        config_path: Path to the YAML config file. If None, uses bundled config.
        checkpoint_path: Path to the trained model checkpoint (.pth file). If None, uses bundled checkpoint.
        pretrained_dir: Directory containing pretrained sub-models (ASR/, JDC/, PLBERT/). If None, uses bundled models.
        device: Device to use ('cuda' or 'cpu'). If None, auto-detects.

    Example:
        # Use bundled models (simplest)
        tts = Chiluka()

        # Or specify custom paths
        tts = Chiluka(
            config_path="my_config.yml",
            checkpoint_path="my_model.pth"
        )
    """

    def __init__(
        self,
        config_path: Optional[str] = None,
        checkpoint_path: Optional[str] = None,
        pretrained_dir: Optional[str] = None,
        device: Optional[str] = None,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        # Resolve paths - use bundled defaults if not specified
        config_path = config_path or str(DEFAULT_CONFIG_PATH)
        checkpoint_path = checkpoint_path or get_default_checkpoint()
        pretrained_dir = pretrained_dir or str(DEFAULT_PRETRAINED_DIR)

        if not checkpoint_path:
            raise ValueError(
                "No checkpoint found. Please either:\n"
                "1. Place a .pth checkpoint in: {}\n"
                "2. Specify checkpoint_path parameter".format(DEFAULT_CHECKPOINT_DIR)
            )

        # Load config
        print(f"Loading config from {config_path}...")
        with open(config_path, 'r') as f:
            self.config = yaml.safe_load(f)

        # Resolve pretrained paths (fixed layout under pretrained_dir)
        self.pretrained_dir = Path(pretrained_dir)
        asr_config = self.pretrained_dir / "ASR" / "config.yml"
        asr_path = self.pretrained_dir / "ASR" / "epoch_00080.pth"
        f0_path = self.pretrained_dir / "JDC" / "bst.t7"
        plbert_dir = self.pretrained_dir / "PLBERT"

        # Verify pretrained models exist (fail fast with a clear message)
        self._verify_pretrained_models(asr_path, f0_path, plbert_dir)

        # Load pretrained models
        print(f"Loading ASR model...")
        self.text_aligner = load_ASR_models(str(asr_path), str(asr_config))

        print(f"Loading F0 model...")
        self.pitch_extractor = load_F0_models(str(f0_path))

        print(f"Loading PL-BERT...")
        self.plbert = load_plbert(str(plbert_dir))

        # Build model from the config's model_params (munch gives attr access)
        self.model_params = recursive_munch(self.config["model_params"])
        self.model = build_model(self.model_params, self.text_aligner, self.pitch_extractor, self.plbert)

        # Load checkpoint
        print(f"Loading checkpoint from {checkpoint_path}...")
        self._load_checkpoint(checkpoint_path)

        # Move to device and set to eval mode
        for key in self.model:
            self.model[key].eval().to(self.device)

        # Build sampler for the style diffusion model
        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
            clamp=False,
        )

        # Text cleaner (phoneme string -> token ids)
        self.textcleaner = TextCleaner()

        # Mel spectrogram transform
        # NOTE(review): these values mirror preprocess_params in the bundled
        # config (sr 24000, n_fft 2048, win 1200, hop 300) — keep in sync.
        self.to_mel = torchaudio.transforms.MelSpectrogram(
            n_mels=80, n_fft=2048, win_length=1200, hop_length=300
        )

        # Cache for phonemizer backends (one espeak backend per language code)
        self._phonemizers = {}

        print("✓ Chiluka TTS initialized successfully!")

    def _verify_pretrained_models(self, asr_path, f0_path, plbert_dir):
        """Verify all pretrained models exist."""
        missing = []
        if not asr_path.exists():
            missing.append(f"ASR model: {asr_path}")
        if not f0_path.exists():
            missing.append(f"F0 model: {f0_path}")
        if not plbert_dir.exists():
            missing.append(f"PLBERT directory: {plbert_dir}")

        if missing:
            raise FileNotFoundError(
                "Missing pretrained models:\n" +
                "\n".join(f" - {m}" for m in missing) +
                f"\n\nExpected in: {self.pretrained_dir}"
            )

    def _load_checkpoint(self, checkpoint_path: str):
        """Load model weights from checkpoint.

        Only sub-models present under checkpoint["net"] are loaded; on a
        strict load failure, retries after stripping DataParallel's
        "module." prefix from the state-dict keys.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        for key in self.model:
            if key in checkpoint["net"]:
                try:
                    self.model[key].load_state_dict(checkpoint["net"][key])
                except Exception:
                    # Checkpoint was likely saved from a DataParallel wrapper.
                    state_dict = checkpoint["net"][key]
                    new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
                    self.model[key].load_state_dict(new_state_dict)

    def _get_phonemizer(self, language: str):
        """Get or create phonemizer backend for a language."""
        if language not in self._phonemizers:
            # Imported lazily so construction cost is paid only when needed.
            import phonemizer
            self._phonemizers[language] = phonemizer.backend.EspeakBackend(
                language=language, preserve_punctuation=True, with_stress=True
            )
        return self._phonemizers[language]

    def _preprocess_mel(self, wave: np.ndarray, mean: float = -4, std: float = 4) -> torch.Tensor:
        """Convert waveform to normalized mel spectrogram.

        Applies log compression (with a 1e-5 floor) then standardizes with
        the fixed (mean, std) used at training time.
        """
        wave_tensor = torch.from_numpy(wave).float()
        mel_tensor = self.to_mel(wave_tensor)
        mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
        return mel_tensor

    def compute_style(self, audio_path: str, sr: int = 24000) -> torch.Tensor:
        """
        Compute style embedding from reference audio.

        Args:
            audio_path: Path to reference audio file
            sr: Target sample rate

        Returns:
            Style embedding tensor: concat of acoustic (style_encoder) and
            prosodic (predictor_encoder) embeddings along dim 1.
        """
        wave, orig_sr = librosa.load(audio_path, sr=sr)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        # NOTE(review): librosa.load(sr=sr) already resamples and returns
        # orig_sr == sr, so this branch appears unreachable — confirm intent.
        if orig_sr != sr:
            audio = librosa.resample(audio, orig_sr=orig_sr, target_sr=sr)

        mel_tensor = self._preprocess_mel(audio).to(self.device)

        with torch.no_grad():
            ref_s = self.model.style_encoder(mel_tensor.unsqueeze(1))
            ref_p = self.model.predictor_encoder(mel_tensor.unsqueeze(1))

        return torch.cat([ref_s, ref_p], dim=1)

    def synthesize(
        self,
        text: str,
        reference_audio: str,
        language: str = "en",
        alpha: float = 0.3,
        beta: float = 0.7,
        diffusion_steps: int = 5,
        embedding_scale: float = 1.0,
        sr: int = 24000,
    ) -> np.ndarray:
        """
        Synthesize speech from text.

        Args:
            text: Input text to synthesize
            reference_audio: Path to reference audio for style transfer
            language: Language code for phonemization (e.g., 'en', 'te', 'hi')
            alpha: Style mixing coefficient for acoustic features (0-1)
            beta: Style mixing coefficient for prosodic features (0-1)
            diffusion_steps: Number of diffusion sampling steps
            embedding_scale: Classifier-free guidance scale
            sr: Sample rate

        Returns:
            Generated audio waveform as numpy array
        """
        # Compute style from reference
        ref_s = self.compute_style(reference_audio, sr=sr)

        # Phonemize text, then re-tokenize so spacing matches training data
        phonemizer = self._get_phonemizer(language)
        text = text.strip()
        ps = phonemizer.phonemize([text])
        ps = word_tokenize(ps[0])
        ps = " ".join(ps)

        # Convert to tokens
        tokens = self.textcleaner(ps)
        tokens.insert(0, 0)  # Add start token
        tokens = torch.LongTensor(tokens).to(self.device).unsqueeze(0)

        # Truncate if too long for the BERT position embeddings
        max_len = self.model.bert.config.max_position_embeddings
        if tokens.shape[-1] > max_len:
            tokens = tokens[:, :max_len]

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(self.device)
            text_mask = length_to_mask(input_lengths).to(self.device)

            # Encode text
            t_en = self.model.text_encoder(tokens, input_lengths, text_mask)
            bert_dur = self.model.bert(tokens, attention_mask=(~text_mask).int())
            d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)

            # Sample a 256-dim style vector conditioned on text + reference
            s_pred = self.sampler(
                noise=torch.randn((1, 256)).unsqueeze(1).to(self.device),
                embedding=bert_dur,
                embedding_scale=embedding_scale,
                features=ref_s,
                num_steps=diffusion_steps,
            ).squeeze(1)

            # First 128 dims = acoustic style, last 128 = prosodic style
            s = s_pred[:, 128:]
            ref = s_pred[:, :128]

            # Mix styles: alpha/beta blend the sampled style with the
            # reference style (higher value -> more of the sampled style)
            ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
            s = beta * s + (1 - beta) * ref_s[:, 128:]

            # Predict duration per token
            d = self.model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
            x, _ = self.model.predictor.lstm(d)
            duration = self.model.predictor.duration_proj(x)
            duration = torch.sigmoid(duration).sum(axis=-1)
            pred_dur = torch.round(duration.squeeze()).clamp(min=1)

            # Build hard monotonic alignment (token i spans pred_dur[i] frames)
            pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
                c_frame += int(pred_dur[i].data)

            en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(self.device))

            # Adjust for hifigan decoder (shift features one frame right)
            if self.model_params.decoder.type == "hifigan":
                asr_new = torch.zeros_like(en)
                asr_new[:, :, 0] = en[:, :, 0]
                asr_new[:, :, 1:] = en[:, :, 0:-1]
                en = asr_new

            # Predict F0 and energy
            F0_pred, N_pred = self.model.predictor.F0Ntrain(en, s)

            # Encode for decoder (same one-frame shift for hifigan)
            asr = (t_en @ pred_aln_trg.unsqueeze(0).to(self.device))
            if self.model_params.decoder.type == "hifigan":
                asr_new = torch.zeros_like(asr)
                asr_new[:, :, 0] = asr[:, :, 0]
                asr_new[:, :, 1:] = asr[:, :, 0:-1]
                asr = asr_new

            # Decode waveform
            out = self.model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Drop the last 50 samples — presumably a decoder edge artifact; TODO confirm
        return out.squeeze().cpu().numpy()[..., :-50]

    def save_wav(self, wav: np.ndarray, path: str, sr: int = 24000):
        """
        Save waveform to WAV file.

        Args:
            wav: Audio waveform as numpy array (float, roughly [-1, 1])
            path: Output file path
            sr: Sample rate
        """
        import scipy.io.wavfile as wavfile
        # Scale to 16-bit PCM, clipping any out-of-range samples.
        wav_int16 = (wav * 32767).clip(-32768, 32767).astype(np.int16)
        wavfile.write(path, sr, wav_int16)
        print(f"Saved audio to {path}")

    def play(self, wav: np.ndarray, sr: int = 24000):
        """
        Play audio through speakers (requires pyaudio).

        Args:
            wav: Audio waveform as numpy array
            sr: Sample rate
        """
        try:
            import pyaudio
        except ImportError:
            raise ImportError("pyaudio is required for playback. Install with: pip install pyaudio")

        # Convert float waveform to 16-bit PCM bytes for the output stream.
        audio_int16 = (wav * 32767.0).clip(-32768, 32767).astype("int16").tobytes()
        p = pyaudio.PyAudio()
        stream = p.open(format=pyaudio.paInt16, channels=1, rate=sr, output=True)
        stream.write(audio_int16)
        stream.stop_stream()
        stream.close()
        p.terminate()
chiluka/models/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model components for Chiluka TTS."""
2
+
3
+ from .core import (
4
+ build_model,
5
+ load_ASR_models,
6
+ load_F0_models,
7
+ load_plbert,
8
+ StyleEncoder,
9
+ TextEncoder,
10
+ ProsodyPredictor,
11
+ )
12
+
13
+ __all__ = [
14
+ "build_model",
15
+ "load_ASR_models",
16
+ "load_F0_models",
17
+ "load_plbert",
18
+ "StyleEncoder",
19
+ "TextEncoder",
20
+ "ProsodyPredictor",
21
+ ]
chiluka/models/core.py ADDED
@@ -0,0 +1,731 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Core model definitions for Chiluka TTS."""
2
+
3
+ import os
4
+ import math
5
+ import yaml
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn.utils import weight_norm, spectral_norm
10
+ from collections import OrderedDict
11
+ from munch import Munch
12
+
13
+ from transformers import AlbertConfig, AlbertModel
14
+
15
+ from .diffusion.sampler import KDiffusion, LogNormalDistribution
16
+ from .diffusion.modules import Transformer1d, StyleTransformer1d
17
+ from .diffusion.diffusion import AudioDiffusionConditional
18
+ from .hifigan import Decoder
19
+
20
+
21
+ # ============== Style Encoder ==============
22
+
23
class DownSample(nn.Module):
    """Fixed (non-learned) 2D downsampling selected by a string tag.

    Modes: 'none' (identity), 'timepreserve' (halve the frequency axis
    only), 'half' (halve both axes, padding an odd time axis first).
    """

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        mode = self.layer_type
        if mode == 'none':
            return x
        if mode == 'timepreserve':
            # 2x1 average pool: downsample frequency, keep time resolution.
            return F.avg_pool2d(x, (2, 1))
        if mode == 'half':
            # Make the last (time) axis even by repeating its final column,
            # then 2x2 average pool.
            if x.shape[-1] % 2 != 0:
                last_col = x[..., -1].unsqueeze(-1)
                x = torch.cat([x, last_col], dim=-1)
            return F.avg_pool2d(x, 2)
        raise RuntimeError(f'Unexpected downsample type {self.layer_type}')
39
+
40
+
41
class LearnedDownSample(nn.Module):
    """Learned counterpart of DownSample: a strided depthwise conv.

    Modes mirror DownSample: 'none' (identity), 'timepreserve' (stride 2
    on frequency only), 'half' (stride 2 on both axes).
    """

    def __init__(self, layer_type, dim_in):
        super().__init__()
        self.layer_type = layer_type
        if layer_type == 'none':
            self.conv = nn.Identity()
        elif layer_type == 'timepreserve':
            # Depthwise (groups=dim_in) 3x1 conv, stride 2 on frequency.
            self.conv = spectral_norm(
                nn.Conv2d(dim_in, dim_in, kernel_size=(3, 1), stride=(2, 1),
                          groups=dim_in, padding=(1, 0)))
        elif layer_type == 'half':
            # Depthwise 3x3 conv, stride 2 on both axes.
            self.conv = spectral_norm(
                nn.Conv2d(dim_in, dim_in, kernel_size=(3, 3), stride=(2, 2),
                          groups=dim_in, padding=1))
        else:
            raise RuntimeError(f'Unexpected downsample type {self.layer_type}')

    def forward(self, x):
        return self.conv(x)
56
+
57
+
58
class ResBlk(nn.Module):
    """Pre-activation residual block with optional downsampling.

    Shortcut path: optional 1x1 conv (when dim_in != dim_out) followed by a
    fixed DownSample. Residual path: [norm] -> act -> conv -> learned
    downsample -> [norm] -> act -> conv. Output is scaled by 1/sqrt(2) to
    keep unit variance after the sum.

    NOTE(review): submodule creation order here determines state_dict key
    names and the RNG stream — do not reorder without re-checking
    checkpoint compatibility.

    NOTE(review): the default actv=nn.LeakyReLU(0.2) is a single module
    instance shared by every ResBlk using the default — harmless here since
    LeakyReLU is stateless, but worth knowing.
    """

    def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2), normalize=False, downsample='none'):
        super().__init__()
        self.actv = actv
        self.normalize = normalize
        self.downsample = DownSample(downsample)
        self.downsample_res = LearnedDownSample(downsample, dim_in)
        # Use a learned 1x1 shortcut only when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out)

    def _build_weights(self, dim_in, dim_out):
        self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
        self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
        if self.normalize:
            self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
            self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
        if self.learned_sc:
            self.conv1x1 = spectral_norm(nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        if self.learned_sc:
            x = self.conv1x1(x)
        # NOTE(review): nn.Module is always truthy, so this branch always
        # runs; DownSample('none') is the identity, so behavior is correct.
        if self.downsample:
            x = self.downsample(x)
        return x

    def _residual(self, x):
        if self.normalize:
            x = self.norm1(x)
        x = self.actv(x)
        x = self.conv1(x)
        x = self.downsample_res(x)
        if self.normalize:
            x = self.norm2(x)
        x = self.actv(x)
        x = self.conv2(x)
        return x

    def forward(self, x):
        x = self._shortcut(x) + self._residual(x)
        # Variance-preserving scale for the sum of two same-variance paths.
        return x / math.sqrt(2)
99
+
100
+
101
class StyleEncoder(nn.Module):
    """Encode a mel spectrogram (B, 1, n_mels, T) into a style vector.

    A conv stem, four halving ResBlks (channels doubling up to
    max_conv_dim), a 5x5 valid conv, and global average pooling feed a
    final linear projection to style_dim.
    """

    def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
        super().__init__()
        layers = [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
        repeat_num = 4
        for _ in range(repeat_num):
            # Double channels each stage, capped at max_conv_dim.
            dim_out = min(dim_in * 2, max_conv_dim)
            layers.append(ResBlk(dim_in, dim_out, downsample='half'))
            dim_in = dim_out
        layers.append(nn.LeakyReLU(0.2))
        layers.append(spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0)))
        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.LeakyReLU(0.2))
        self.shared = nn.Sequential(*layers)
        self.unshared = nn.Linear(dim_out, style_dim)

    def forward(self, x):
        pooled = self.shared(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.unshared(flat)
123
+
124
+
125
+ # ============== Text Encoder ==============
126
+
127
class LayerNorm(nn.Module):
    """Channel-wise layer norm for (B, C, ...) tensors.

    Moves the channel axis last, applies ``F.layer_norm`` with learned
    affine parameters, then moves it back.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        transposed = x.transpose(1, -1)
        normed = F.layer_norm(transposed, (self.channels,), self.gamma, self.beta, self.eps)
        return normed.transpose(1, -1)
139
+
140
+
141
class LinearNorm(nn.Module):
    """``nn.Linear`` with Xavier-uniform initialization scaled by ``w_init_gain``."""

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super().__init__()
        self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
        gain = nn.init.calculate_gain(w_init_gain)
        nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        return self.linear_layer(x)
149
+
150
+
151
class TextEncoder(nn.Module):
    """Encode phoneme tokens into frame-level hidden features.

    Embedding -> ``depth`` masked conv blocks -> BiLSTM, with padded
    positions zeroed throughout.
    """

    def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
        super().__init__()
        self.embedding = nn.Embedding(n_symbols, channels)
        padding = (kernel_size - 1) // 2
        self.cnn = nn.ModuleList()
        for _ in range(depth):
            self.cnn.append(nn.Sequential(
                weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
                LayerNorm(channels),
                actv,
                nn.Dropout(0.2),
            ))
        # Bidirectional halves keep the total feature size at `channels`.
        self.lstm = nn.LSTM(channels, channels // 2, 1, batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths, m):
        """
        Args:
            x: (B, T) LongTensor of token ids.
            input_lengths: (B,) valid lengths per sequence.
            m: (B, T) bool mask, True at padded positions.

        Returns:
            (B, channels, T) features with padded positions zeroed.
        """
        x = self.embedding(x)   # (B, T, channels)
        x = x.transpose(1, 2)   # (B, channels, T)
        m = m.to(input_lengths.device).unsqueeze(1)
        x.masked_fill_(m, 0.0)
        for c in self.cnn:
            x = c(x)
            # Re-zero padded positions after every conv block.
            x.masked_fill_(m, 0.0)
        x = x.transpose(1, 2)
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        x = x.transpose(-1, -2)
        # pad_packed_sequence pads only to the longest sequence in the batch;
        # re-pad to the full mask length. Allocate directly on the source
        # device/dtype instead of building on CPU and copying across devices.
        x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device, dtype=x.dtype)
        x_pad[:, :, :x.shape[-1]] = x
        x = x_pad
        x.masked_fill_(m, 0.0)
        return x
186
+
187
+
188
+ # ============== Prosody Predictor ==============
189
+
190
class AdaIN1d(nn.Module):
    """Adaptive instance norm: per-channel scale/shift predicted from a style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        # Predicts concatenated (gamma, beta) from the style embedding.
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        params = self.fc(s).unsqueeze(-1)        # (B, 2*C, 1)
        gamma, beta = params.chunk(2, dim=1)     # (B, C, 1) each
        return (1 + gamma) * self.norm(x) + beta
201
+
202
+
203
class UpSample1d(nn.Module):
    """Optionally doubles temporal resolution with nearest-neighbour upsampling."""

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        # 'none' means identity; any other layer type upsamples by 2x.
        if self.layer_type != 'none':
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return x
213
+
214
+
215
class AdainResBlk1d(nn.Module):
    """1-D residual block with adaptive instance norm and optional 2x upsampling.

    The residual path uses a learned depthwise transposed-conv upsampler
    (``pool``); the shortcut uses plain nearest-neighbour upsampling.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            # Learned 2x upsampler for the residual path (depthwise).
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        """Create the style-conditioned norms and convolutions."""
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            # 1x1 conv so the shortcut can match the output channel count.
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        # Nearest-neighbour upsample (not the learned pool), then project.
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        # norm -> act -> (learned upsample) -> conv, twice.
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        # Scale the sum to keep roughly unit variance.
        out = (out + self._shortcut(x)) / math.sqrt(2)
        return out
257
+
258
+
259
class AdaLayerNorm(nn.Module):
    """Layer norm whose affine parameters are predicted from a style vector."""

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Predicts concatenated per-channel (gamma, beta) from the style.
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        # NOTE(review): for 3-D input these two transposes cancel each other
        # out; kept as-is to match the reference implementation exactly.
        x = x.transpose(-1, -2)
        x = x.transpose(1, -1)
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
        # Normalize over the channel axis (no learned elementwise affine),
        # then apply the style-predicted scale and shift.
        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        x = (1 + gamma) * x + beta
        return x.transpose(1, -1).transpose(-1, -2)
276
+
277
+
278
class DurationEncoder(nn.Module):
    """Style-conditioned text encoder used for duration prediction.

    Alternates BiLSTM layers with adaptive layer norm, re-concatenating the
    style vector after each norm so every layer stays style-conditioned.
    """

    def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
        super().__init__()
        self.lstms = nn.ModuleList()
        for _ in range(nlayers):
            self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
            self.lstms.append(AdaLayerNorm(sty_dim, d_model))
        self.dropout = dropout
        self.d_model = d_model
        self.sty_dim = sty_dim

    def forward(self, x, style, text_lengths, m):
        """
        Args:
            x: (B, d_model, T) text features.
            style: (B, sty_dim) style vector.
            text_lengths: (B,) valid lengths.
            m: (B, T) bool mask, True at padded positions.

        Returns:
            (B, T, d_model + sty_dim) style-augmented features.
        """
        masks = m.to(text_lengths.device)
        x = x.permute(2, 0, 1)                        # (T, B, d_model)
        s = style.expand(x.shape[0], x.shape[1], -1)  # broadcast style over time
        x = torch.cat([x, s], axis=-1)
        x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
        x = x.transpose(0, 1)
        input_lengths = text_lengths.cpu().numpy()
        x = x.transpose(-1, -2)                       # (B, d_model + sty, T)
        for block in self.lstms:
            if isinstance(block, AdaLayerNorm):
                x = block(x.transpose(-1, -2), style).transpose(-1, -2)
                # Re-attach the style features after normalization.
                x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
                x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
            else:
                x = x.transpose(-1, -2)
                x = nn.utils.rnn.pack_padded_sequence(x, input_lengths, batch_first=True, enforce_sorted=False)
                block.flatten_parameters()
                x, _ = block(x)
                x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = x.transpose(-1, -2)
                # Re-pad to the full mask length, allocating directly on the
                # source device/dtype (avoids a CPU alloc + cross-device copy).
                x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device, dtype=x.dtype)
                x_pad[:, :, :x.shape[-1]] = x
                x = x_pad
        return x.transpose(-1, -2)
315
+
316
+
317
class ProsodyPredictor(nn.Module):
    """Predict phoneme durations plus frame-level F0 and energy from text + style."""

    def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
        super().__init__()
        self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout)
        self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.duration_proj = LinearNorm(d_hid, max_dur)
        # Shared BiLSTM feeding both the F0 and the energy (N) branches.
        self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
        self.F0 = nn.ModuleList()
        self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
        self.N = nn.ModuleList()
        self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
        self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
        self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
        self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)

    def forward(self, texts, style, text_lengths, alignment, m):
        """Return (duration logits (B, T, max_dur), frame-aligned encoder features)."""
        d = self.text_encoder(texts, style, text_lengths, m)
        input_lengths = text_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(d, input_lengths, batch_first=True, enforce_sorted=False)
        m = m.to(text_lengths.device).unsqueeze(1)
        self.lstm.flatten_parameters()
        x, _ = self.lstm(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
        # Re-pad to the full mask length, allocating directly on the right
        # device/dtype (previously built on CPU then copied across devices).
        x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device, dtype=x.dtype)
        x_pad[:, :x.shape[1], :] = x
        x = x_pad
        duration = self.duration_proj(F.dropout(x, 0.5, training=self.training))
        # Expand phoneme-level features to frame level via the alignment matrix.
        en = (d.transpose(-1, -2) @ alignment)
        return duration.squeeze(-1), en

    def F0Ntrain(self, x, s):
        """Predict F0 and energy curves from aligned features ``x`` and style ``s``."""
        x, _ = self.shared(x.transpose(-1, -2))
        F0 = x.transpose(-1, -2)
        for block in self.F0:
            F0 = block(F0, s)
        F0 = self.F0_proj(F0)
        N = x.transpose(-1, -2)
        for block in self.N:
            N = block(N, s)
        N = self.N_proj(N)
        return F0.squeeze(1), N.squeeze(1)
361
+
362
+
363
+ # ============== Pretrained Model Loaders ==============
364
+
365
class CustomAlbert(AlbertModel):
    """AlbertModel wrapper that returns only the last hidden state tensor."""

    def forward(self, *args, **kwargs):
        result = super().forward(*args, **kwargs)
        return result.last_hidden_state
369
+
370
+
371
def load_plbert(log_dir):
    """Load a pre-trained PL-BERT model from ``log_dir``.

    Reads ``config.yml`` for the ALBERT configuration, picks the latest
    ``step_*.t7`` checkpoint, and loads its encoder weights after stripping
    the DataParallel ``module.`` and ``encoder.`` prefixes.
    """
    config_path = os.path.join(log_dir, "config.yml")
    # Close the config file explicitly instead of leaking the handle.
    with open(config_path) as f:
        plbert_config = yaml.safe_load(f)
    albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
    bert = CustomAlbert(albert_base_configuration)

    # Pick the checkpoint with the highest step number.
    files = os.listdir(log_dir)
    ckpts = [f for f in files if f.startswith("step_")]
    iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
    latest = sorted(iters)[-1]
    ckpt_path = os.path.join(log_dir, f"step_{latest}.t7")
    # weights_only=False is required for this legacy checkpoint format; fall
    # back for older torch versions lacking the argument (consistent with
    # load_ASR_models).
    try:
        checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
    except TypeError:
        checkpoint = torch.load(ckpt_path, map_location='cpu')

    state_dict = checkpoint['net']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        if name.startswith('encoder.'):
            name = name[8:]  # remove `encoder.`
        new_state_dict[name] = v
    # Newer transformers versions no longer register position_ids as a buffer.
    if "embeddings.position_ids" in new_state_dict:
        del new_state_dict["embeddings.position_ids"]
    bert.load_state_dict(new_state_dict, strict=False)
    return bert
393
+
394
+
395
+ # ASR model components
396
+ import torchaudio
397
+ import torchaudio.functional as audio_F
398
+
399
+
400
class MFCC(nn.Module):
    """Convert mel spectrograms to MFCCs via a fixed DCT matrix."""

    def __init__(self, n_mfcc=40, n_mels=80):
        super().__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = 'ortho'
        # The DCT matrix is constant, so store it as a non-trainable buffer.
        self.register_buffer('dct_mat', audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm))

    def forward(self, mel_specgram):
        # Accept both (n_mels, T) and (B, n_mels, T) inputs.
        squeeze_back = mel_specgram.dim() == 2
        if squeeze_back:
            mel_specgram = mel_specgram.unsqueeze(0)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
        return mfcc.squeeze(0) if squeeze_back else mfcc
419
+
420
+
421
class ConvNorm(nn.Module):
    """1-D convolution with Xavier-uniform init and 'same' padding by default."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super().__init__()
        if padding is None:
            # Default to "same" padding (for odd kernel sizes).
            padding = int(dilation * (kernel_size - 1) / 2)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation, bias=bias)
        nn.init.xavier_uniform_(self.conv.weight, gain=nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        return self.conv(signal)
431
+
432
+
433
class ConvBlock(nn.Module):
    """Stack of residual dilated-conv sub-blocks (dilations 1, 3, 9, ...)."""

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
        super().__init__()
        self._n_groups = 8
        self.blocks = nn.ModuleList(
            self._get_conv(hidden_dim, dilation=3 ** i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)
        )

    def forward(self, x):
        # Each sub-block is applied residually.
        for block in self.blocks:
            x = x + block(x)
        return x

    def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
        """Build one sub-block: dilated conv -> act -> GN -> dropout -> conv -> act -> dropout."""
        def activation():
            return nn.ReLU() if activ == 'relu' else nn.LeakyReLU(0.2)
        return nn.Sequential(
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            activation(),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            activation(),
            nn.Dropout(p=dropout_p),
        )
457
+
458
+
459
class LocationLayer(nn.Module):
    """Project previous/cumulative attention weights into attention space."""

    def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
        super().__init__()
        padding = (attention_kernel_size - 1) // 2
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1, dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed = self.location_conv(attention_weights_cat)
        return self.location_dense(processed.transpose(1, 2))
471
+
472
+
473
class Attention(nn.Module):
    """Location-sensitive additive attention (Tacotron 2 style)."""

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, attention_location_n_filters, attention_location_kernel_size):
        super().__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters, attention_location_kernel_size, attention_dim)
        # Masked positions get -inf energy so softmax assigns them zero weight.
        self.score_mask_value = -float("inf")

    def forward(self, attention_hidden_state, memory, processed_memory, attention_weights_cat, mask):
        """Return (context vector (B, H), attention weights (B, L))."""
        # Additive energies from query, location features, and memory.
        processed_query = self.query_layer(attention_hidden_state.unsqueeze(1))
        processed_attention = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(processed_query + processed_attention + processed_memory))
        energies = energies.squeeze(-1)
        if mask is not None:
            energies.data.masked_fill_(mask, self.score_mask_value)
        attention_weights = F.softmax(energies, dim=1)
        # Weighted sum over memory yields the context vector.
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)
        return attention_context, attention_weights
493
+
494
+
495
class ASRS2S(nn.Module):
    """Attention-based sequence-to-sequence decoder head of the ASR model.

    Decodes token embeddings step by step with an LSTMCell and
    location-sensitive attention over the encoder memory (teacher forcing).
    """

    def __init__(self, embedding_dim=256, hidden_dim=512, n_location_filters=32, location_kernel_size=63, n_token=40):
        super().__init__()
        self.embedding = nn.Embedding(n_token, embedding_dim)
        # Uniform init in [-sqrt(6/H), sqrt(6/H)].
        val_range = math.sqrt(6 / hidden_dim)
        self.embedding.weight.data.uniform_(-val_range, val_range)
        self.decoder_rnn_dim = hidden_dim
        self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
        self.attention_layer = Attention(self.decoder_rnn_dim, hidden_dim, hidden_dim, n_location_filters, location_kernel_size)
        self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
        self.project_to_hidden = nn.Sequential(LinearNorm(self.decoder_rnn_dim * 2, hidden_dim), nn.Tanh())
        # Special token ids: start-of-sequence, end-of-sequence, unknown.
        self.sos = 1
        self.eos = 2
        self.unk_index = 3
        # Probability of replacing an input token with <unk> during teacher
        # forcing (input-dropout regularization).
        self.random_mask = 0.1

    def initialize_decoder_states(self, memory, mask):
        """Reset per-utterance decoder state (stored on self for decode())."""
        B, L, H = memory.shape
        self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
        self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
        self.attention_weights = torch.zeros((B, L)).type_as(memory)
        self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
        self.attention_context = torch.zeros((B, H)).type_as(memory)
        self.memory = memory
        # Project memory once; reused at every decode step.
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def forward(self, memory, memory_mask, text_input):
        """Teacher-forced decoding over the whole target sequence.

        Returns:
            hidden_outputs: (B, T+1, hidden_dim) projected decoder states.
            logit_outputs: (B, T+1, n_token) token logits.
            alignments: (B, T+1, L) per-step attention weights.
        """
        self.initialize_decoder_states(memory, memory_mask)
        # Randomly corrupt a fraction of input tokens to <unk>.
        random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
        _text_input = text_input.clone()
        _text_input.masked_fill_(random_mask, self.unk_index)
        decoder_inputs = self.embedding(_text_input).transpose(0, 1)  # (T, B, E)
        # Prepend the <sos> embedding as the first decoder input.
        start_embedding = self.embedding(torch.LongTensor([self.sos] * decoder_inputs.size(1)).to(decoder_inputs.device))
        decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
        hidden_outputs, logit_outputs, alignments = [], [], []
        while len(hidden_outputs) < decoder_inputs.size(0):
            decoder_input = decoder_inputs[len(hidden_outputs)]
            hidden, logit, attention_weights = self.decode(decoder_input)
            hidden_outputs += [hidden]
            logit_outputs += [logit]
            alignments += [attention_weights]
        hidden_outputs = torch.stack(hidden_outputs).transpose(0, 1).contiguous()
        logit_outputs = torch.stack(logit_outputs).transpose(0, 1).contiguous()
        alignments = torch.stack(alignments).transpose(0, 1)
        return hidden_outputs, logit_outputs, alignments

    def decode(self, decoder_input):
        """One decoder step: LSTMCell update, attention, then projections."""
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(cell_input, (self.decoder_hidden, self.decoder_cell))
        # Location features: current and cumulative attention weights.
        attention_weights_cat = torch.cat((self.attention_weights.unsqueeze(1), self.attention_weights_cum.unsqueeze(1)), dim=1)
        self.attention_context, self.attention_weights = self.attention_layer(self.decoder_hidden, self.memory, self.processed_memory, attention_weights_cat, self.mask)
        self.attention_weights_cum += self.attention_weights
        hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
        hidden = self.project_to_hidden(hidden_and_context)
        logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
        return hidden, logit, self.attention_weights
552
+
553
+
554
class ASRCNN(nn.Module):
    """CNN-based ASR model with a CTC head and an optional seq2seq decoder head."""

    def __init__(self, input_dim=80, hidden_dim=256, n_token=35, n_layers=6, token_embedding_dim=256):
        super().__init__()
        self.n_token = n_token
        self.n_down = 1
        self.to_mfcc = MFCC()
        # MFCC halves the feature dimension; init_cnn halves the time axis (stride 2).
        self.init_cnn = ConvNorm(input_dim // 2, hidden_dim, kernel_size=7, padding=3, stride=2)
        self.cnns = nn.Sequential(*[
            nn.Sequential(ConvBlock(hidden_dim), nn.GroupNorm(num_groups=1, num_channels=hidden_dim))
            for _ in range(n_layers)
        ])
        self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
        self.ctc_linear = nn.Sequential(
            LinearNorm(hidden_dim // 2, hidden_dim),
            nn.ReLU(),
            LinearNorm(hidden_dim, n_token),
        )
        self.asr_s2s = ASRS2S(embedding_dim=token_embedding_dim, hidden_dim=hidden_dim // 2, n_token=n_token)

    def forward(self, x, src_key_padding_mask=None, text_input=None):
        features = self.projection(self.cnns(self.init_cnn(self.to_mfcc(x))))
        features = features.transpose(1, 2)
        ctc_logit = self.ctc_linear(features)
        if text_input is None:
            # CTC-only path when no target transcript is provided.
            return ctc_logit
        _, s2s_logit, s2s_attn = self.asr_s2s(features, src_key_padding_mask, text_input)
        return ctc_logit, s2s_logit, s2s_attn
578
+
579
+
580
def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
    """Build an ASRCNN from a YAML config and load its checkpoint weights."""
    with open(ASR_MODEL_CONFIG) as f:
        config = yaml.safe_load(f)
    model = ASRCNN(**config['model_params'])
    # weights_only was added in newer torch versions; fall back when absent.
    try:
        ckpt = torch.load(ASR_MODEL_PATH, map_location="cpu", weights_only=False)
    except TypeError:
        ckpt = torch.load(ASR_MODEL_PATH, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    return model
593
+
594
+
595
+ # JDC (F0) model
596
class ResBlock_JDC(nn.Module):
    """JDC residual block: BN/LReLU/width-pool, then a two-conv residual path."""

    def __init__(self, in_channels, out_channels, leaky_relu_slope=0.01):
        super().__init__()
        self.downsample = in_channels != out_channels
        # Pre-activation plus pooling along the last (frequency) axis.
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2)),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )
        # 1x1 conv aligns channel counts on the skip path when they differ.
        self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False) if self.downsample else None

    def forward(self, x):
        x = self.pre_conv(x)
        skip = self.conv1by1(x) if self.downsample else x
        return self.conv(x) + skip
613
+
614
+
615
class JDCNet(nn.Module):
    """Joint detection & classification network, used here as an F0 extractor.

    Only the conv/res/pool stack and the classifier BiLSTM are used by
    forward(); the maxpool/detector sub-modules below are not referenced by
    this forward pass (presumably kept so pretrained checkpoints load
    cleanly — confirm before removing).
    """

    def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
        super().__init__()
        self.num_class = num_class
        self.conv_block = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(num_features=64), nn.LeakyReLU(leaky_relu_slope, inplace=True), nn.Conv2d(64, 64, 3, padding=1, bias=False))
        self.res_block1 = ResBlock_JDC(in_channels=64, out_channels=128)
        self.res_block2 = ResBlock_JDC(in_channels=128, out_channels=192)
        self.res_block3 = ResBlock_JDC(in_channels=192, out_channels=256)
        self.pool_block = nn.Sequential(nn.BatchNorm2d(num_features=256), nn.LeakyReLU(leaky_relu_slope, inplace=True), nn.MaxPool2d(kernel_size=(1, 4)), nn.Dropout(p=0.2))
        # Maxpool layers for auxiliary network (unused in forward()).
        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
        self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
        # Detector conv (unused in forward()).
        self.detector_conv = nn.Sequential(nn.Conv2d(640, 256, 1, bias=False), nn.BatchNorm2d(256), nn.LeakyReLU(leaky_relu_slope, inplace=True), nn.Dropout(p=0.2))
        # Classifier and detector LSTMs (only the classifier path is used).
        self.bilstm_classifier = nn.LSTM(input_size=512, hidden_size=256, batch_first=True, bidirectional=True)
        self.bilstm_detector = nn.LSTM(input_size=512, hidden_size=256, batch_first=True, bidirectional=True)
        # Output layers
        self.classifier = nn.Linear(in_features=512, out_features=self.num_class)
        self.detector = nn.Linear(in_features=512, out_features=2)

    def forward(self, x):
        """Return (abs class logits (squeezed), GAN feature map, pooled features)."""
        seq_len = x.shape[-1]
        x = x.float().transpose(-1, -2)
        convblock_out = self.conv_block(x)
        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        # Apply the pool block piecewise so the pre-pool activation can be
        # captured and returned as the GAN feature.
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)
        GAN_feature = poolblock_out.transpose(-1, -2)
        poolblock_out = self.pool_block[2](poolblock_out)
        classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
        classifier_out, _ = self.bilstm_classifier(classifier_out)
        classifier_out = classifier_out.contiguous().view((-1, 512))
        classifier_out = self.classifier(classifier_out)
        classifier_out = classifier_out.view((-1, seq_len, self.num_class))
        return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
654
+
655
+
656
def load_F0_models(path):
    """Load the pre-trained JDC pitch (F0) extractor from ``path``."""
    F0_model = JDCNet(num_class=1, seq_len=192)
    # weights_only=False is needed for this legacy checkpoint format; fall
    # back for older torch versions (consistent with load_ASR_models).
    try:
        params = torch.load(path, map_location='cpu', weights_only=False)['net']
    except TypeError:
        params = torch.load(path, map_location='cpu')['net']
    F0_model.load_state_dict(params)
    return F0_model
662
+
663
+
664
+ # ============== Build Model ==============
665
+
666
def build_model(args, text_aligner, pitch_extractor, bert):
    """Assemble all TTS sub-networks into a single Munch.

    Args:
        args: config namespace with model hyper-parameters.
        text_aligner: pretrained ASR model used for text/speech alignment.
        pitch_extractor: pretrained JDC F0 model.
        bert: pretrained PL-BERT text model.

    Returns:
        Munch mapping component names to their nn.Module instances.
    """
    assert args.decoder.type in ['istftnet', 'hifigan'], 'Decoder type unknown'

    decoder = Decoder(
        dim_in=args.hidden_dim,
        style_dim=args.style_dim,
        dim_out=args.n_mels,
        resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
        upsample_rates=args.decoder.upsample_rates,
        upsample_initial_channel=args.decoder.upsample_initial_channel,
        resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
        upsample_kernel_sizes=args.decoder.upsample_kernel_sizes
    )

    text_encoder = TextEncoder(channels=args.hidden_dim, kernel_size=5, depth=args.n_layer, n_symbols=args.n_token)
    predictor = ProsodyPredictor(style_dim=args.style_dim, d_hid=args.hidden_dim, nlayers=args.n_layer, max_dur=args.max_dur, dropout=args.dropout)
    # Two separate style encoders: one for the decoder (acoustic) style and
    # one for the prosody predictor's style.
    style_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)
    predictor_encoder = StyleEncoder(dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim)

    # Diffusion denoiser backbone: style-conditioned variant for
    # multi-speaker models, plain transformer otherwise.
    if args.multispeaker:
        transformer = StyleTransformer1d(
            channels=args.style_dim * 2,
            context_embedding_features=bert.config.hidden_size,
            context_features=args.style_dim * 2,
            **args.diffusion.transformer
        )
    else:
        transformer = Transformer1d(
            channels=args.style_dim * 2,
            context_embedding_features=bert.config.hidden_size,
            **args.diffusion.transformer
        )

    diffusion = AudioDiffusionConditional(
        in_channels=1,
        embedding_max_length=bert.config.max_position_embeddings,
        embedding_features=bert.config.hidden_size,
        embedding_mask_proba=args.diffusion.embedding_mask_proba,
        channels=args.style_dim * 2,
        context_features=args.style_dim * 2,
    )

    # Replace the default diffusion with a K-diffusion wrapper and swap its
    # denoising net for the transformer built above.
    diffusion.diffusion = KDiffusion(
        net=diffusion.unet,
        sigma_distribution=LogNormalDistribution(mean=args.diffusion.dist.mean, std=args.diffusion.dist.std),
        sigma_data=args.diffusion.dist.sigma_data,
        dynamic_threshold=0.0
    )
    diffusion.diffusion.net = transformer
    diffusion.unet = transformer

    nets = Munch(
        bert=bert,
        bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
        predictor=predictor,
        decoder=decoder,
        text_encoder=text_encoder,
        predictor_encoder=predictor_encoder,
        style_encoder=style_encoder,
        diffusion=diffusion,
        text_aligner=text_aligner,
        pitch_extractor=pitch_extractor,
    )

    return nets
chiluka/models/diffusion/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Diffusion model components."""
2
+
3
+ from .sampler import (
4
+ DiffusionSampler,
5
+ ADPM2Sampler,
6
+ KarrasSchedule,
7
+ KDiffusion,
8
+ LogNormalDistribution,
9
+ )
10
+ from .modules import Transformer1d, StyleTransformer1d
11
+ from .diffusion import AudioDiffusionConditional
12
+
13
+ __all__ = [
14
+ "DiffusionSampler",
15
+ "ADPM2Sampler",
16
+ "KarrasSchedule",
17
+ "KDiffusion",
18
+ "LogNormalDistribution",
19
+ "Transformer1d",
20
+ "StyleTransformer1d",
21
+ "AudioDiffusionConditional",
22
+ ]
chiluka/models/diffusion/diffusion.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Audio diffusion model classes."""
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from .utils import groupby
7
+ from .sampler import UniformDistribution
8
+
9
+
10
class LinearSchedule(nn.Module):
    """Linear sigma schedule from 1 down to (but excluding) 0."""

    def forward(self, num_steps: int, device) -> Tensor:
        """Return ``num_steps`` sigmas linearly spaced in (0, 1] on ``device``.

        Previously the ``device`` argument was accepted but ignored, so the
        schedule was always created on CPU; create it on the target device.
        """
        return torch.linspace(1, 0, num_steps + 1, device=device)[:-1]
14
+
15
+
16
class VSampler(nn.Module):
    """Stub v-diffusion sampler.

    NOTE(review): intentionally empty — appears to exist only so
    get_default_sampling_kwargs() can name a sampler; confirm before use.
    """
    pass
18
+
19
+
20
class Model1d(nn.Module):
    """Shell for a 1-D diffusion model.

    ``unet`` and ``diffusion`` start as None and are attached externally
    (build_model assigns both after construction).
    """

    def __init__(self, unet_type: str = "base", **kwargs):
        super().__init__()
        # Strip "diffusion_"-prefixed kwargs; the split-off values are
        # currently unused because the nets are wired in afterwards.
        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
        self.unet = None
        self.diffusion = None

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        """Delegate to the attached diffusion module (must be set first)."""
        return self.diffusion(x, **kwargs)

    def sample(self, *args, **kwargs) -> Tensor:
        """Delegate sampling to the attached diffusion module."""
        return self.diffusion.sample(*args, **kwargs)
32
+
33
+
34
def get_default_model_kwargs():
    """Default UNet/diffusion hyper-parameters for the audio diffusion model."""
    return {
        "channels": 128,
        "patch_size": 16,
        "multipliers": [1, 2, 4, 4, 4, 4, 4],
        "factors": [4, 4, 4, 2, 2, 2],
        "num_blocks": [2, 2, 2, 2, 2, 2],
        "attentions": [0, 0, 0, 1, 1, 1, 1],
        "attention_heads": 8,
        "attention_features": 64,
        "attention_multiplier": 2,
        "attention_use_rel_pos": False,
        "diffusion_type": "v",
        "diffusion_sigma_distribution": UniformDistribution(),
    }
49
+
50
+
51
def get_default_sampling_kwargs():
    """Default sampling configuration: linear schedule + v-sampler, clamped output."""
    return {"sigma_schedule": LinearSchedule(), "sampler": VSampler(), "clamp": True}
53
+
54
+
55
class AudioDiffusionConditional(Model1d):
    """Text-conditional audio diffusion model (classifier-free-guidance UNet)."""

    def __init__(self, embedding_features: int, embedding_max_length: int, embedding_mask_proba: float = 0.1, **kwargs):
        default_kwargs = dict(
            **get_default_model_kwargs(),
            unet_type="cfg",
            context_embedding_features=embedding_features,
            context_embedding_max_length=embedding_max_length,
        )
        # Caller-supplied kwargs override the defaults.
        super().__init__(**{**default_kwargs, **kwargs})
        # Assign instance attributes only after nn.Module initialization;
        # setting them on a half-initialized Module is fragile.
        self.embedding_mask_proba = embedding_mask_proba

    def forward(self, *args, **kwargs):
        """Forward with the configured CFG embedding-mask probability by default."""
        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
        return super().forward(*args, **{**default_kwargs, **kwargs})

    def sample(self, *args, **kwargs):
        """Sample with default schedule/sampler and embedding guidance scale 5.0."""
        default_kwargs = dict(**get_default_sampling_kwargs(), embedding_scale=5.0)
        return super().sample(*args, **{**default_kwargs, **kwargs})
chiluka/models/diffusion/modules.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Diffusion transformer modules."""
2
+
3
+ from math import log, pi
4
+ from typing import Optional
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, reduce, repeat
10
+ from einops.layers.torch import Rearrange
11
+ from torch import Tensor, einsum
12
+
13
+ from .utils import exists, default, rand_bool
14
+
15
+
16
class AdaLayerNorm(nn.Module):
    """Layer norm whose affine parameters (gamma, beta) are predicted from a style vector."""

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Projects the style vector to per-channel gamma and beta.
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        # Rotate the channel axis to the last position so layer_norm normalises
        # over channels, then rotate back at the end.
        x = x.transpose(-1, -2).transpose(1, -1)
        style = self.fc(s).view(s.size(0), -1, 1)
        gamma, beta = torch.chunk(style, chunks=2, dim=1)
        gamma = gamma.transpose(1, -1)
        beta = beta.transpose(1, -1)
        normed = F.layer_norm(x, (self.channels,), eps=self.eps)
        out = normed * (1 + gamma) + beta
        return out.transpose(1, -1).transpose(-1, -2)
33
+
34
+
35
class LearnedPositionalEmbedding(nn.Module):
    """Fourier features of a scalar input (e.g. noise level) with learned frequencies.

    Output per element is [x, sin(2*pi*f*x), cos(2*pi*f*x)] for `dim // 2`
    learned frequencies f, i.e. `dim + 1` features in total.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        half_dim = dim // 2
        self.weights = nn.Parameter(torch.randn(half_dim))

    def forward(self, x: Tensor) -> Tensor:
        # Plain unsqueeze replaces einops.rearrange for these trivial reshapes,
        # removing the third-party dependency from this module.
        x = x.unsqueeze(-1)                              # (b,) -> (b, 1)
        freqs = x * self.weights.unsqueeze(0) * 2 * pi   # (b, half_dim)
        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
        # Prepend the raw input so the projection can also see x directly.
        return torch.cat((x, fouriered), dim=-1)
48
+
49
+
50
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Embed scalar timesteps: learned Fourier features, then a linear projection.

    The +1 accounts for the raw input prepended by LearnedPositionalEmbedding.
    """
    return nn.Sequential(
        LearnedPositionalEmbedding(dim),
        nn.Linear(dim + 1, out_features),
    )
55
+
56
+
57
class FixedEmbedding(nn.Module):
    """Learned absolute-position embedding table broadcast over the batch.

    The output depends only on the sequence length of `x`, never on its values;
    it serves as the "masked"/unconditional embedding for classifier-free
    guidance (see the transformer `forward` methods in this file).
    """

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        batch_size, length = x.shape[0], x.shape[1]
        assert length <= self.max_length, "Input sequence length must be <= max_length"
        positions = torch.arange(length, device=x.device)
        # (length, features) -> (batch, length, features) as a broadcast view:
        # values identical to einops.repeat, without the dependency or the copy.
        return self.embedding(positions).unsqueeze(0).expand(batch_size, -1, -1)
70
+
71
+
72
class RelativePositionBias(nn.Module):
    """T5-style relative position bias: bucketed relative distances mapped to a
    learned per-head additive attention bias."""

    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.num_heads = num_heads
        # One learned bias value per (bucket, head).
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position: Tensor, num_buckets: int, max_distance: int):
        # Half the buckets encode the sign (key before/after query), the other
        # half encode magnitude.
        num_buckets //= 2
        ret = (relative_position >= 0).to(torch.long) * num_buckets
        n = torch.abs(relative_position)
        # Small distances get one bucket each; larger distances are binned
        # logarithmically up to max_distance, then clamped into the last bucket.
        max_exact = num_buckets // 2
        is_small = n < max_exact
        val_if_large = max_exact + (torch.log(n.float() / max_exact) / log(max_distance / max_exact) * (num_buckets - max_exact)).long()
        val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, num_queries: int, num_keys: int) -> Tensor:
        """Return an additive bias of shape (1, heads, num_queries, num_keys)."""
        i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
        # Queries are aligned to the *end* of the key sequence (positions j-i .. j-1).
        q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
        relative_position_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
        bias = self.relative_attention_bias(relative_position_bucket)
        # (queries, keys, heads) -> (1, heads, queries, keys) to broadcast over batch.
        bias = rearrange(bias, "m n h -> 1 h m n")
        return bias
101
+
102
+
103
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Two-layer GELU MLP: expand `features` by `multiplier`, then project back."""
    hidden = features * multiplier
    return nn.Sequential(
        nn.Linear(features, hidden),
        nn.GELU(),
        nn.Linear(hidden, features),
    )
110
+
111
+
112
class AttentionBase(nn.Module):
    """Multi-head scaled dot-product attention over already-projected q/k/v."""

    def __init__(self, features: int, *, head_features: int, num_heads: int, use_rel_pos: bool,
                 out_features: Optional[int] = None, rel_pos_num_buckets: Optional[int] = None,
                 rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        self.use_rel_pos = use_rel_pos
        mid_features = head_features * num_heads
        if use_rel_pos:
            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
            self.rel_pos = RelativePositionBias(num_buckets=rel_pos_num_buckets, max_distance=rel_pos_max_distance, num_heads=num_heads)
        # Default the output width to the input feature width.
        if out_features is None:
            out_features = features
        self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # q/k/v: (batch, seq, heads * head_features) -> one axis per head.
        h = self.num_heads
        q = rearrange(q, "b n (h d) -> b h n d", h=h)
        k = rearrange(k, "b n (h d) -> b h n d", h=h)
        v = rearrange(v, "b n (h d) -> b h n d", h=h)
        sim = einsum("b h n d, b h m d -> b h n m", q, k)
        sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
        # NOTE(review): the 1/sqrt(d) scale is applied *after* adding the
        # relative-position bias, so the bias itself is scaled too -- confirm
        # this ordering is intended.
        sim = sim * self.scale
        attn = sim.softmax(dim=-1)
        out = einsum("b h n m, b h m d -> b h n d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
140
+
141
+
142
class StyleAttention(nn.Module):
    """Self/cross attention whose pre-norms are style-conditioned (AdaLayerNorm)."""

    def __init__(self, features: int, *, style_dim: int, head_features: int, num_heads: int,
                 context_features: Optional[int] = None, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        self.context_features = context_features
        inner_dim = head_features * num_heads
        ctx_dim = default(context_features, features)
        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, ctx_dim)
        self.to_q = nn.Linear(features, inner_dim, bias=False)
        self.to_kv = nn.Linear(ctx_dim, inner_dim * 2, bias=False)
        self.attention = AttentionBase(features, num_heads=num_heads, head_features=head_features,
                                       use_rel_pos=use_rel_pos, rel_pos_num_buckets=rel_pos_num_buckets,
                                       rel_pos_max_distance=rel_pos_max_distance)

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        # Without an explicit context this is plain self-attention.
        if context is None:
            context = x
        x = self.norm(x, s)
        context = self.norm_context(context, s)
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        return self.attention(q, k, v)
163
+
164
+
165
class Attention(nn.Module):
    """Standard pre-norm self/cross attention (LayerNorm, not style-conditioned)."""

    def __init__(self, features: int, *, head_features: int, num_heads: int, out_features: Optional[int] = None,
                 context_features: Optional[int] = None, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        self.context_features = context_features
        inner_dim = head_features * num_heads
        ctx_dim = default(context_features, features)
        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(ctx_dim)
        self.to_q = nn.Linear(features, inner_dim, bias=False)
        self.to_kv = nn.Linear(ctx_dim, inner_dim * 2, bias=False)
        self.attention = AttentionBase(features, out_features=out_features, num_heads=num_heads,
                                       head_features=head_features, use_rel_pos=use_rel_pos,
                                       rel_pos_num_buckets=rel_pos_num_buckets,
                                       rel_pos_max_distance=rel_pos_max_distance)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        # Without an explicit context this is plain self-attention.
        if context is None:
            context = x
        x = self.norm(x)
        context = self.norm_context(context)
        q = self.to_q(x)
        k, v = self.to_kv(context).chunk(2, dim=-1)
        return self.attention(q, k, v)
186
+
187
+
188
class StyleTransformerBlock(nn.Module):
    """Style-conditioned transformer block: self-attention, optional cross-attention,
    and a feed-forward MLP, each with a residual connection."""

    def __init__(self, features: int, num_heads: int, head_features: int, style_dim: int, multiplier: int,
                 use_rel_pos: bool, rel_pos_num_buckets: Optional[int] = None,
                 rel_pos_max_distance: Optional[int] = None, context_features: Optional[int] = None):
        super().__init__()
        # Cross-attention only when a positive context width was configured.
        self.use_cross_attention = exists(context_features) and context_features > 0
        self.attention = StyleAttention(features=features, style_dim=style_dim, num_heads=num_heads,
                                        head_features=head_features, use_rel_pos=use_rel_pos,
                                        rel_pos_num_buckets=rel_pos_num_buckets,
                                        rel_pos_max_distance=rel_pos_max_distance)
        if self.use_cross_attention:
            self.cross_attention = StyleAttention(features=features, style_dim=style_dim, num_heads=num_heads,
                                                  head_features=head_features, context_features=context_features,
                                                  use_rel_pos=use_rel_pos, rel_pos_num_buckets=rel_pos_num_buckets,
                                                  rel_pos_max_distance=rel_pos_max_distance)
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        x = x + self.attention(x, s)
        if self.use_cross_attention:
            x = x + self.cross_attention(x, s, context=context)
        return x + self.feed_forward(x)
209
+
210
+
211
class TransformerBlock(nn.Module):
    """Plain transformer block: self-attention, optional cross-attention, and a
    feed-forward MLP, each with a residual connection."""

    def __init__(self, features: int, num_heads: int, head_features: int, multiplier: int, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None):
        super().__init__()
        # Cross-attention only when a positive context width was configured.
        self.use_cross_attention = exists(context_features) and context_features > 0
        self.attention = Attention(features=features, num_heads=num_heads, head_features=head_features,
                                   use_rel_pos=use_rel_pos, rel_pos_num_buckets=rel_pos_num_buckets,
                                   rel_pos_max_distance=rel_pos_max_distance)
        if self.use_cross_attention:
            self.cross_attention = Attention(features=features, num_heads=num_heads, head_features=head_features,
                                             context_features=context_features, use_rel_pos=use_rel_pos,
                                             rel_pos_num_buckets=rel_pos_num_buckets,
                                             rel_pos_max_distance=rel_pos_max_distance)
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        x = x + self.attention(x)
        if self.use_cross_attention:
            x = x + self.cross_attention(x, context=context)
        return x + self.feed_forward(x)
232
+
233
+
234
class StyleTransformer1d(nn.Module):
    """Style-conditioned transformer over a 1-D latent concatenated with a
    conditioning embedding, with classifier-free guidance support.

    Each block sees [x, embedding] along the channel axis; `features` (passed
    as `s` to the blocks) drives the AdaLayerNorms.  A mapping vector built
    from the diffusion time and the style features is added before every block.
    """

    def __init__(self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int,
                 use_context_time: bool = True, use_rel_pos: bool = False, context_features_multiplier: int = 1,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None, context_embedding_features: Optional[int] = None,
                 embedding_max_length: int = 512):
        # NOTE(review): context_features_multiplier is accepted but never used.
        super().__init__()
        self.blocks = nn.ModuleList([
            # Blocks operate on the concatenated width; context_features doubles
            # as the style dimension for AdaLayerNorm.
            StyleTransformerBlock(features=channels + context_embedding_features, head_features=head_features, num_heads=num_heads,
                                  multiplier=multiplier, style_dim=context_features, use_rel_pos=use_rel_pos,
                                  rel_pos_num_buckets=rel_pos_num_buckets, rel_pos_max_distance=rel_pos_max_distance)
            for _ in range(num_layers)
        ])
        # Project the concatenated width back down to `channels`.
        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(in_channels=channels + context_embedding_features, out_channels=channels, kernel_size=1),
        )
        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time
        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features
            # Shared MLP applied to the summed time/feature embeddings.
            self.to_mapping = nn.Sequential(nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(),
                                            nn.Linear(context_mapping_features, context_mapping_features), nn.GELU())
        if use_context_time:
            self.to_time = nn.Sequential(TimePositionalEmbedding(dim=channels, out_features=context_mapping_features), nn.GELU())
        if use_context_features:
            self.to_features = nn.Sequential(nn.Linear(in_features=context_features, out_features=context_mapping_features), nn.GELU())
        # Unconditional (masked) embedding used for classifier-free guidance.
        self.fixed_embedding = FixedEmbedding(max_length=embedding_max_length, features=context_embedding_features)

    def get_mapping(self, time: Optional[Tensor] = None, features: Optional[Tensor] = None) -> Optional[Tensor]:
        """Sum the time and feature embeddings, then pass through the mapping MLP.

        Returns None when neither context source is enabled.
        """
        items, mapping = [], None
        if self.use_context_time:
            items += [self.to_time(time)]
        if self.use_context_features:
            items += [self.to_features(features)]
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)
        return mapping

    def run(self, x, time, embedding, features):
        # NOTE(review): x appears to carry a singleton token axis -- it is
        # expanded to the embedding length before concatenation; confirm callers.
        mapping = self.get_mapping(time, features)
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
        for block in self.blocks:
            x = x + mapping  # re-inject conditioning before every block
            x = block(x, features)
        # Pool over tokens: output time length is 1.
        x = x.mean(axis=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)
        return x

    def forward(self, x: Tensor, time: Tensor, embedding_mask_proba: float = 0.0, embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None, embedding_scale: float = 1.0) -> Tensor:
        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # Training-time CFG: randomly replace whole-batch-element embeddings
            # with the fixed (unconditional) embedding.
            batch_mask = rand_bool(shape=(b, 1, 1), proba=embedding_mask_proba, device=device)
            embedding = torch.where(batch_mask, fixed_embedding, embedding)
        if embedding_scale != 1.0:
            # Inference-time CFG: interpolate/extrapolate between the
            # unconditional and conditional outputs.
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            return out_masked + (out - out_masked) * embedding_scale
        else:
            return self.run(x, time, embedding=embedding, features=features)
300
+
301
+
302
class Transformer1d(nn.Module):
    """Unconditional-norm variant of StyleTransformer1d: same concatenate-and-map
    scheme and classifier-free guidance, but with plain LayerNorm blocks."""

    def __init__(self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int,
                 use_context_time: bool = True, use_rel_pos: bool = False, context_features_multiplier: int = 1,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None, context_embedding_features: Optional[int] = None,
                 embedding_max_length: int = 512):
        # NOTE(review): context_features_multiplier is accepted but never used.
        super().__init__()
        self.blocks = nn.ModuleList([
            TransformerBlock(features=channels + context_embedding_features, head_features=head_features, num_heads=num_heads,
                             multiplier=multiplier, use_rel_pos=use_rel_pos, rel_pos_num_buckets=rel_pos_num_buckets,
                             rel_pos_max_distance=rel_pos_max_distance)
            for _ in range(num_layers)
        ])
        # Project the concatenated width back down to `channels`.
        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(in_channels=channels + context_embedding_features, out_channels=channels, kernel_size=1),
        )
        use_context_features = exists(context_features)
        self.use_context_features = use_context_features
        self.use_context_time = use_context_time
        if use_context_time or use_context_features:
            context_mapping_features = channels + context_embedding_features
            self.to_mapping = nn.Sequential(nn.Linear(context_mapping_features, context_mapping_features), nn.GELU(),
                                            nn.Linear(context_mapping_features, context_mapping_features), nn.GELU())
        if use_context_time:
            self.to_time = nn.Sequential(TimePositionalEmbedding(dim=channels, out_features=context_mapping_features), nn.GELU())
        if use_context_features:
            self.to_features = nn.Sequential(nn.Linear(in_features=context_features, out_features=context_mapping_features), nn.GELU())
        # Unconditional (masked) embedding used for classifier-free guidance.
        self.fixed_embedding = FixedEmbedding(max_length=embedding_max_length, features=context_embedding_features)

    def get_mapping(self, time: Optional[Tensor] = None, features: Optional[Tensor] = None) -> Optional[Tensor]:
        """Sum the time and feature embeddings, then pass through the mapping MLP.

        Returns None when neither context source is enabled.
        """
        items, mapping = [], None
        if self.use_context_time:
            items += [self.to_time(time)]
        if self.use_context_features:
            items += [self.to_features(features)]
        if self.use_context_time or self.use_context_features:
            mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
            mapping = self.to_mapping(mapping)
        return mapping

    def run(self, x, time, embedding, features):
        # NOTE(review): x appears to carry a singleton token axis -- it is
        # expanded to the embedding length before concatenation; confirm callers.
        mapping = self.get_mapping(time, features)
        x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
        mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
        for block in self.blocks:
            x = x + mapping  # re-inject conditioning before every block
            x = block(x)
        # Pool over tokens: output time length is 1.
        x = x.mean(axis=1).unsqueeze(1)
        x = self.to_out(x)
        x = x.transpose(-1, -2)
        return x

    def forward(self, x: Tensor, time: Tensor, embedding_mask_proba: float = 0.0, embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None, embedding_scale: float = 1.0) -> Tensor:
        b, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)
        if embedding_mask_proba > 0.0:
            # Training-time CFG: randomly replace whole-batch-element embeddings
            # with the fixed (unconditional) embedding.
            batch_mask = rand_bool(shape=(b, 1, 1), proba=embedding_mask_proba, device=device)
            embedding = torch.where(batch_mask, fixed_embedding, embedding)
        if embedding_scale != 1.0:
            # Inference-time CFG: interpolate/extrapolate between the
            # unconditional and conditional outputs.
            out = self.run(x, time, embedding=embedding, features=features)
            out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
            return out_masked + (out - out_masked) * embedding_scale
        else:
            return self.run(x, time, embedding=embedding, features=features)
chiluka/models/diffusion/sampler.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Diffusion sampling classes."""
2
+
3
+ from math import atan, cos, pi, sin, sqrt
4
+ from typing import Any, Callable, List, Optional, Tuple, Type
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, reduce
10
+ from torch import Tensor
11
+
12
+ from .utils import exists, default
13
+
14
+
15
class Distribution:
    """Interface for drawing a batch of noise levels (sigmas)."""

    def __call__(self, num_samples: int, device: torch.device):
        raise NotImplementedError()
18
+
19
+
20
class LogNormalDistribution(Distribution):
    """Sigmas drawn as exp(N(mean, std)) -- the EDM training distribution."""

    def __init__(self, mean: float, std: float):
        self.mean = mean
        self.std = std

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")) -> Tensor:
        gaussian = self.mean + self.std * torch.randn((num_samples,), device=device)
        return gaussian.exp()
28
+
29
+
30
class UniformDistribution(Distribution):
    """Sigmas drawn uniformly from [0, 1) -- used by v-diffusion."""

    def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
        return torch.rand(num_samples, device=device)
33
+
34
+
35
def to_batch(batch_size: int, device: torch.device, x: Optional[float] = None, xs: Optional[Tensor] = None) -> Tensor:
    """Broadcast a scalar `x` to a (batch_size,) tensor, or pass `xs` through.

    Exactly one of `x` / `xs` must be provided.
    """
    assert (x is None) != (xs is None), "Either x or xs must be provided"
    if xs is None:
        # Allocate directly on the target device instead of a CPU tensor
        # followed by a .to(device) copy.
        xs = torch.full(size=(batch_size,), fill_value=x, device=device)
    return xs
41
+
42
+
43
class Diffusion(nn.Module):
    """Base class for diffusion objectives; subclasses supply the denoiser and loss."""

    alias: str = ""

    def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor:
        raise NotImplementedError("Diffusion class missing denoise_fn")

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        raise NotImplementedError("Diffusion class missing forward function")
51
+
52
+
53
class KDiffusion(Diffusion):
    """Elucidated Diffusion (Karras et al. 2022)"""

    alias = "k"

    def __init__(self, net: nn.Module, *, sigma_distribution: Distribution, sigma_data: float, dynamic_threshold: float = 0.0):
        super().__init__()
        self.net = net
        self.sigma_data = sigma_data  # assumed std of the clean data
        self.sigma_distribution = sigma_distribution  # training noise-level sampler
        # NOTE(review): dynamic_threshold is stored but never used in this class.
        self.dynamic_threshold = dynamic_threshold

    def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
        """EDM preconditioning coefficients (c_skip, c_out, c_in, c_noise) for `sigmas`."""
        sigma_data = self.sigma_data
        c_noise = torch.log(sigmas) * 0.25
        sigmas = rearrange(sigmas, "b -> b 1 1")
        c_skip = (sigma_data ** 2) / (sigmas ** 2 + sigma_data ** 2)
        c_out = sigmas * sigma_data * (sigma_data ** 2 + sigmas ** 2) ** -0.5
        c_in = (sigmas ** 2 + sigma_data ** 2) ** -0.5
        return c_skip, c_out, c_in, c_noise

    def denoise_fn(self, x_noisy: Tensor, sigmas: Optional[Tensor] = None, sigma: Optional[float] = None, **kwargs) -> Tensor:
        """Denoise `x_noisy` at noise level given per-batch (`sigmas`) or scalar (`sigma`)."""
        batch_size, device = x_noisy.shape[0], x_noisy.device
        sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
        # Preconditioned network call mixed with the skip connection.
        c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
        x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
        x_denoised = c_skip * x_noisy + c_out * x_pred
        return x_denoised

    def loss_weight(self, sigmas: Tensor) -> Tensor:
        # EDM loss weighting: balances the MSE contribution across noise levels.
        return (sigmas ** 2 + self.sigma_data ** 2) * (sigmas * self.sigma_data) ** -2

    def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
        """Weighted denoising-MSE training loss on batch `x` at random noise levels."""
        batch_size, device = x.shape[0], x.device
        sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
        sigmas_padded = rearrange(sigmas, "b -> b 1 1")
        noise = default(noise, lambda: torch.randn_like(x))
        x_noisy = x + sigmas_padded * noise
        x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
        losses = F.mse_loss(x_denoised, x, reduction="none")
        losses = reduce(losses, "b ... -> b", "mean")  # mean over all non-batch dims
        losses = losses * self.loss_weight(sigmas)
        return losses.mean()
96
+
97
+
98
class Schedule(nn.Module):
    """Interface producing the tensor of sigmas used during sampling."""

    def forward(self, num_steps: int, device: torch.device) -> Tensor:
        raise NotImplementedError()
101
+
102
+
103
class KarrasSchedule(Schedule):
    """Karras et al. (2022) noise schedule: rho-warped interpolation from
    sigma_max down to sigma_min, with a trailing zero appended."""

    def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def forward(self, num_steps: int, device: Any) -> Tensor:
        inv_rho = 1.0 / self.rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        # Interpolate linearly in sigma^(1/rho) space, then warp back.
        sigmas = (
            self.sigma_max ** inv_rho
            + (steps / (num_steps - 1))
            * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)
        ) ** self.rho
        # Final entry is sigma = 0 (fully denoised).
        return F.pad(sigmas, pad=(0, 1), value=0.0)
120
+
121
+
122
class Sampler(nn.Module):
    """Interface for iterative samplers; `diffusion_types` lists the compatible
    Diffusion objective classes."""

    diffusion_types: List[Type[Diffusion]] = []

    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        raise NotImplementedError()
127
+
128
+
129
class ADPM2Sampler(Sampler):
    """Ancestral DPM-2 sampler (as in the k-diffusion library) for KDiffusion models."""

    diffusion_types = [KDiffusion]

    def __init__(self, rho: float = 1.0):
        super().__init__()
        self.rho = rho

    def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
        """Split a step into its ancestral-noise (up), deterministic (down) and midpoint sigmas."""
        r = self.rho
        sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
        sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
        sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
        return sigma_up, sigma_down, sigma_mid

    def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
        """One ancestral DPM-2 step from `sigma` to `sigma_next`."""
        sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
        # Derivative at sigma, half-step to the midpoint.
        d = (x - fn(x, sigma=sigma)) / sigma
        x_mid = x + d * (sigma_mid - sigma)
        # Corrected derivative at the midpoint drives the deterministic part...
        d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
        x = x + d_mid * (sigma_down - sigma)
        # ...followed by the ancestral noise injection.
        return x + torch.randn_like(x) * sigma_up

    def forward(self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int) -> Tensor:
        # Start from pure noise scaled to the largest sigma.
        x = sigmas[0] * noise
        for i in range(num_steps - 1):
            x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1])
        return x
157
+
158
+
159
class DiffusionSampler(nn.Module):
    """Runs a Sampler over a sigma Schedule using a diffusion model's denoiser."""

    def __init__(self, diffusion: Diffusion, *, sampler: Sampler, sigma_schedule: Schedule, num_steps: Optional[int] = None, clamp: bool = True):
        super().__init__()
        self.denoise_fn = diffusion.denoise_fn
        self.sampler = sampler
        self.sigma_schedule = sigma_schedule
        self.num_steps = num_steps
        self.clamp = clamp

    def forward(self, noise: Tensor, num_steps: Optional[int] = None, **kwargs) -> Tensor:
        steps = default(num_steps, self.num_steps)
        assert exists(steps), "Parameter `num_steps` must be provided"
        sigmas = self.sigma_schedule(steps, noise.device)

        # Bind the conditioning kwargs into the denoiser passed to the sampler.
        def denoise(*args, **inner):
            return self.denoise_fn(*args, **{**inner, **kwargs})

        x = self.sampler(noise, fn=denoise, sigmas=sigmas, num_steps=steps)
        return x.clamp(-1.0, 1.0) if self.clamp else x
chiluka/models/diffusion/utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Diffusion utility functions."""
2
+
3
+ from functools import reduce
4
+ from inspect import isfunction
5
+ from math import ceil, floor, log2
6
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
    """Return `val` when present, otherwise the default `d`.

    A (plain) function default is called lazily; any other value is returned as-is.
    """
    if exists(val):
        return val
    return d() if isfunction(d) else d
23
+
24
+
25
def rand_bool(shape, proba, device=None):
    """Boolean tensor of `shape`, each element True with probability `proba`.

    Degenerate probabilities (0 and 1) skip the RNG entirely.
    """
    if proba == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if proba == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
32
+
33
+
34
def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
    """Split dict `d` into (entries whose key starts with `prefix`, the rest).

    Unless `keep_prefix` is set, the prefix is stripped from the matching keys.
    (Unrelated to itertools.groupby.)
    """
    matched, rest = {}, {}
    for key, value in d.items():
        if key.startswith(prefix):
            matched[key] = value
        else:
            rest[key] = value
    if not keep_prefix:
        matched = {key[len(prefix):]: value for key, value in matched.items()}
    return matched, rest
chiluka/models/hifigan.py ADDED
@@ -0,0 +1,266 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """HiFi-GAN decoder for waveform synthesis."""
2
+
3
+ import math
4
+ import random
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import Conv1d, ConvTranspose1d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
def init_weights(m, mean=0.0, std=0.01):
    """Initialise Conv*-layer weights from N(mean, std); leaves other modules untouched."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
19
+
20
+
21
def get_padding(kernel_size, dilation=1):
    """'Same' padding for a (odd) dilated conv: dilation * (kernel_size - 1) // 2."""
    return dilation * (kernel_size - 1) // 2
23
+
24
+
25
class AdaIN1d(nn.Module):
    """Adaptive instance norm: gamma/beta for InstanceNorm1d predicted from a style vector."""

    def __init__(self, style_dim, num_features):
        super().__init__()
        self.norm = nn.InstanceNorm1d(num_features, affine=False)
        self.fc = nn.Linear(style_dim, num_features * 2)

    def forward(self, x, s):
        # (batch, 2 * num_features) -> (batch, 2 * num_features, 1) so the
        # affine parameters broadcast over the time axis.
        style = self.fc(s).unsqueeze(-1)
        gamma, beta = style.chunk(2, dim=1)
        return self.norm(x) * (1 + gamma) + beta
36
+
37
+
38
class AdaINResBlock1(nn.Module):
    """HiFi-GAN residual block with style-adaptive instance norm and a learnable
    snake-style periodic activation (x + sin^2(a*x)/a)."""

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super().__init__()
        # Dilated convs (receptive-field growth) ...
        self.convs1 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)
        # ... each followed by an undilated conv.
        self.convs2 = nn.ModuleList([
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
        # One style-conditioned norm before each conv.
        self.adain1 = nn.ModuleList([AdaIN1d(style_dim, channels) for _ in range(3)])
        self.adain2 = nn.ModuleList([AdaIN1d(style_dim, channels) for _ in range(3)])
        # Learnable per-channel frequency of the periodic activation.
        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for _ in range(len(self.convs2))])

    def forward(self, x, s):
        """Apply the three (norm -> snake -> conv) x 2 sub-blocks, each residual."""
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # snake activation
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # snake activation
            xt = c2(xt)
            x = xt + x  # residual connection
        return x
68
+
69
+
70
class SineGen(nn.Module):
    """Harmonic sine-wave generator driven by an F0 contour (NSF-style source).

    Assumes f0 shaped (batch, time, 1) -- the harmonic broadcast in `forward`
    and the time indexing in `_f02sine` rely on this layout; confirm callers.
    """

    def __init__(self, samp_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
        super().__init__()
        self.sine_amp = sine_amp          # amplitude of the voiced sine component
        self.noise_std = noise_std        # additive noise std in voiced regions
        self.harmonic_num = harmonic_num  # number of overtones above the fundamental
        self.dim = harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        # NOTE(review): flag_for_pulse is stored but never used in this class.
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # Voiced/unvoiced mask: 1.0 where f0 exceeds the threshold.
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """Convert per-sample frequencies to sine waves via phase accumulation."""
        # Normalised instantaneous frequency (cycles per sample, wrapped to [0, 1)).
        rad_values = (f0_values / self.sampling_rate) % 1
        # Random initial phase per harmonic (fundamental kept at phase 0).
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        # Accumulate phase at the (downsampled) frame rate, then upsample back,
        # which keeps the cumulative sum numerically tame on long signals.
        rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
        phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
        phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
        sines = torch.sin(phase)
        return sines

    def forward(self, f0):
        """Return (sine_waves, uv_mask, noise) for the fundamental plus harmonics."""
        # Frequencies of the fundamental and its harmonics: f0 * [1 .. harmonic_num + 1].
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
        sine_waves = self._f02sine(fn) * self.sine_amp
        uv = self._f02uv(f0)
        # Voiced regions get low-level noise; unvoiced regions get noise only,
        # at a third of the sine amplitude.
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise
105
+
106
+
107
class SourceModuleHnNSF(nn.Module):
    """Harmonic-plus-noise source module (hn-NSF).

    Generates a harmonic sine stack from F0 via :class:`SineGen`, then
    merges the harmonics into a single excitation channel with a learned
    linear + tanh projection.
    """

    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
        # Collapse harmonic_num + 1 sinusoid channels into one source signal.
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x):
        """x: F0 contour -> (sine_merge, noise, uv)."""
        # Sine generation carries no trainable parameters that should
        # receive gradient through this path.
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        # Independent noise branch shaped like the voiced/unvoiced mask.
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv
122
+
123
+
124
class Generator(nn.Module):
    """HiFi-GAN-style waveform generator with style-conditioned (AdaIN)
    residual blocks and an hn-NSF harmonic source branch.

    Each upsampling stage applies a Snake-like activation, a transposed
    convolution, injects a downsampled harmonic source, then averages a
    bank of multi-kernel AdaIN residual blocks.
    """

    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        resblock = AdaINResBlock1
        # Harmonic source operates at the final audio rate (24 kHz assumed here).
        self.m_source = SourceModuleHnNSF(sampling_rate=24000, upsample_scale=np.prod(upsample_rates), harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = nn.Upsample(scale_factor=np.prod(upsample_rates))
        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.noise_res = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every upsampling stage.
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            self.ups.append(weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(u // 2 + u % 2), output_padding=u % 2)))
            if i + 1 < len(upsample_rates):
                # Strided conv brings the full-rate harmonic source down to
                # this stage's temporal resolution.
                stride_f0 = np.prod(upsample_rates[i + 1:])
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0 + 1) // 2))
                self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
            else:
                # Last stage already runs at the source rate.
                self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
                self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))
        self.resblocks = nn.ModuleList()
        # Snake activation scales (one per stage boundary, incl. the input).
        self.alphas = nn.ParameterList()
        self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(resblock(ch, k, d, style_dim))
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x, s, f0):
        """x: acoustic features; s: style vector; f0: frame-rate pitch contour.

        Returns a waveform in [-1, 1] (tanh-bounded).
        """
        # Upsample F0 to audio rate and synthesize the harmonic source.
        f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
        har_source, noi_source, uv = self.m_source(f0)
        har_source = har_source.transpose(1, 2)
        for i in range(self.num_upsamples):
            # Snake activation: x + (1/a) * sin^2(a * x).
            x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
            x_source = self.noise_convs[i](har_source)
            x_source = self.noise_res[i](x_source, s)
            x = self.ups[i](x)
            x = x + x_source
            # Average the multi-kernel residual block bank for this stage.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x, s)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x, s)
            x = xs / self.num_kernels
        x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)
        x = self.conv_post(x)
        x = torch.tanh(x)
        return x
178
+
179
+
180
class UpSample1d(nn.Module):
    """Optional 2x nearest-neighbour upsampling along the time axis.

    ``layer_type == 'none'`` makes the module a no-op; any other value
    doubles the temporal resolution.
    """

    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type != 'none':
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        return x
190
+
191
+
192
class AdainResBlk1d(nn.Module):
    """1-D residual block with AdaIN style conditioning and optional 2x
    temporal upsampling; output is scaled by 1/sqrt(2) for unit variance.
    """

    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        # Learned 1x1 shortcut only needed when channel counts differ.
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            # Depthwise transposed conv doubles the residual path's length,
            # mirroring the nearest-neighbour upsample on the shortcut path.
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        # Core conv/norm stack; AdaIN1d injects the style vector per layer.
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        # norm -> act -> (upsample) -> conv, twice.
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        """x: (B, dim_in, T); s: style vector -> (B, dim_out, T or 2T)."""
        out = self._residual(x, s)
        # Divide by sqrt(2) to keep the sum's variance comparable to the inputs'.
        out = (out + self._shortcut(x)) / math.sqrt(2)
        return out
234
+
235
+
236
class Decoder(nn.Module):
    """Acoustic decoder: fuses aligned text features (asr), pitch (F0) and
    energy (N) under style conditioning, then vocodes with :class:`Generator`.
    """

    def __init__(self, dim_in=512, F0_channel=512, style_dim=64, dim_out=80, resblock_kernel_sizes=[3, 7, 11],
                 upsample_rates=[10, 5, 3, 2], upsample_initial_channel=512, resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
                 upsample_kernel_sizes=[20, 10, 6, 4]):
        # NOTE(review): mutable default arguments are shared across calls;
        # safe only because this module never mutates them.
        super().__init__()
        self.decode = nn.ModuleList()
        # +2 channels: one each for the downsampled F0 and energy curves.
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
        # +64 channels: the asr residual projection re-injected at each block.
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
        # Stride-2 convs halve F0/N to match the encoder's temporal rate.
        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes)

    def forward(self, asr, F0_curve, N, s):
        """asr: (B, 512, T) aligned features; F0_curve/N: (B, T) contours;
        s: style vector. Returns a waveform from the generator.
        """
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))
        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)
        asr_res = self.asr_res(asr)
        # Re-concatenate the conditioning at every block until the first
        # upsampling block changes the temporal resolution.
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False
        # Generator consumes the full-rate F0 curve for its harmonic source.
        x = self.generator(x, s, F0_curve)
        return x
chiluka/pretrained/ASR/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
chiluka/pretrained/ASR/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes). View file
 
chiluka/pretrained/ASR/__pycache__/layers.cpython-310.pyc ADDED
Binary file (11 kB). View file
 
chiluka/pretrained/ASR/__pycache__/models.cpython-310.pyc ADDED
Binary file (6.12 kB). View file
 
chiluka/pretrained/ASR/config.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "logs/20201006"
2
+ save_freq: 5
3
+ device: "cuda"
4
+ epochs: 180
5
+ batch_size: 64
6
+ pretrained_model: ""
7
+ train_data: "ASRDataset/train_list.txt"
8
+ val_data: "ASRDataset/val_list.txt"
9
+
10
+ dataset_params:
11
+ data_augmentation: false
12
+
13
+ preprocess_parasm:
14
+ sr: 24000
15
+ spect_params:
16
+ n_fft: 2048
17
+ win_length: 1200
18
+ hop_length: 300
19
+ mel_params:
20
+ n_mels: 80
21
+
22
+ model_params:
23
+ input_dim: 80
24
+ hidden_dim: 256
25
+ n_token: 178
26
+ token_embedding_dim: 512
27
+
28
+ optimizer_params:
29
+ lr: 0.0005
chiluka/pretrained/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
chiluka/pretrained/ASR/layers.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
class LinearNorm(torch.nn.Module):
    """Fully connected layer with Xavier-uniform weight initialisation.

    The gain is derived from *w_init_gain* via
    ``torch.nn.init.calculate_gain``.
    """

    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
        gain = torch.nn.init.calculate_gain(w_init_gain)
        torch.nn.init.xavier_uniform_(self.linear_layer.weight, gain=gain)

    def forward(self, x):
        return self.linear_layer(x)
35
+
36
+
37
class ConvNorm(torch.nn.Module):
    """1-D convolution with Xavier-uniform init and automatic 'same'
    padding when *padding* is None (requires an odd kernel size).
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
        super(ConvNorm, self).__init__()
        if padding is None:
            # Symmetric 'same' padding is only well-defined for odd kernels.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride,
            padding=padding, dilation=dilation, bias=bias)

        gain = torch.nn.init.calculate_gain(w_init_gain, param=param)
        torch.nn.init.xavier_uniform_(self.conv.weight, gain=gain)

    def forward(self, signal):
        return self.conv(signal)
56
+
57
class CausualConv(nn.Module):
    """Causal 1-D convolution: pads symmetrically by ``2 * padding`` and
    trims the tail so no output frame depends on future input frames.

    Bug fix: the original only assigned ``self.padding`` in the
    ``padding is not None`` branch, so passing ``padding=None`` raised
    AttributeError when building the Conv1d. The tail trim is now also
    skipped when the padding is 0, since ``x[:, :, :-0]`` would return an
    empty tensor instead of the full output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1,
                 dilation=1, bias=True, w_init_gain='linear', param=None):
        super(CausualConv, self).__init__()
        if padding is None:
            # Derive the symmetric 'same' padding; odd kernels only.
            assert kernel_size % 2 == 1
            padding = int(dilation * (kernel_size - 1) / 2)
        # Double the padding; the extra right-hand context is cut in forward().
        self.padding = padding * 2
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=kernel_size, stride=stride,
                              padding=self.padding,
                              dilation=dilation,
                              bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))

    def forward(self, x):
        x = self.conv(x)
        if self.padding > 0:
            # Drop trailing frames that saw future input (causality).
            x = x[:, :, :-self.padding]
        return x
78
+
79
class CausualBlock(nn.Module):
    """Stack of residual causal-convolution sub-blocks with exponentially
    growing dilation (3**i per sub-block).
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
        super(CausualBlock, self).__init__()
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        # Residual connection around each sub-block.
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
        # conv(dilated) -> act -> BN -> dropout -> conv -> act -> dropout.
        layers = [
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(p=dropout_p),
            CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)
104
+
105
class ConvBlock(nn.Module):
    """Stack of residual (non-causal) convolution sub-blocks with
    exponentially growing dilation; GroupNorm instead of BatchNorm.
    """

    def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
        super().__init__()
        # GroupNorm group count; hidden_dim must be divisible by this.
        self._n_groups = 8
        self.blocks = nn.ModuleList([
            self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
            for i in range(n_conv)])

    def forward(self, x):
        # Residual connection around each sub-block.
        for block in self.blocks:
            res = x
            x = block(x)
            x += res
        return x

    def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
        # conv(dilated) -> act -> GN -> dropout -> conv -> act -> dropout.
        layers = [
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
            _get_activation_fn(activ),
            nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
            nn.Dropout(p=dropout_p),
            ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
            _get_activation_fn(activ),
            nn.Dropout(p=dropout_p)
        ]
        return nn.Sequential(*layers)
132
+
133
class LocationLayer(nn.Module):
    """Location feature extractor for location-sensitive attention:
    convolves the (previous, cumulative) attention weights and projects
    them into attention space.
    """

    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        # Input has 2 channels: previous and cumulative attention weights.
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        """attention_weights_cat: (B, 2, max_time) -> (B, max_time, attention_dim)."""
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention
150
+
151
+
152
class Attention(nn.Module):
    """Location-sensitive additive attention (Tacotron-style)."""

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        # -inf so masked positions get exactly zero weight after softmax.
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """
        # Additive (Bahdanau) scoring: v^T tanh(query + location + memory).
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        RETURNS
        -------
        (attention_context, attention_weights)
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        # Suppress padded encoder positions before the softmax.
        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        # Weighted sum of encoder memory -> context vector.
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights
209
+
210
+
211
class ForwardAttentionV2(nn.Module):
    """Forward attention: like :class:`Attention` but the alignment is
    constrained to move monotonically, by combining (in log space) each
    position's previous alpha with the alpha of the position one step back.
    """

    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(ForwardAttentionV2, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        # Large finite negative instead of -inf: keeps logsumexp well-defined.
        self.score_mask_value = -float(1e20)

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
        RETURNS
        -------
        alignment (batch, max_time)
        """
        # Additive scoring: v^T tanh(query + location + memory).
        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask, log_alpha):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        log_alpha: (B, max_time) log of the previous forward variable
        RETURNS
        -------
        (attention_context, attention_weights, log_alpha_new)
        """
        log_energy = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            log_energy.data.masked_fill_(mask, self.score_mask_value)

        # Forward-attention recursion: each position may be reached from
        # itself (shift 0) or from its left neighbour (shift 1). Build both
        # shifted copies of log_alpha, padding with the mask value.
        log_alpha_shift_padded = []
        max_time = log_energy.size(1)
        for sft in range(2):
            shifted = log_alpha[:, :max_time-sft]
            shift_padded = F.pad(shifted, (sft, 0), 'constant', self.score_mask_value)
            log_alpha_shift_padded.append(shift_padded.unsqueeze(2))

        # Log-domain sum over the two transition hypotheses.
        biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)

        log_alpha_new = biased + log_energy

        attention_weights = F.softmax(log_alpha_new, dim=1)

        # Weighted sum of encoder memory -> context vector.
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
class PhaseShuffle2d(nn.Module):
    """Rotate the last (time) axis of a (B, C, M, L) tensor by a random
    amount in [-n, n] (phase shuffle, as used in WaveGAN discriminators).
    """

    def __init__(self, n=2):
        super(PhaseShuffle2d, self).__init__()
        self.n = n
        # Dedicated seeded RNG so the shift sequence is reproducible.
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, M, L); an explicit `move` overrides the RNG.
        if move is None:
            move = self.random.randint(-self.n, self.n)
        if move == 0:
            return x
        head = x[:, :, :, :move]
        tail = x[:, :, :, move:]
        return torch.cat([tail, head], dim=3)
311
+
312
class PhaseShuffle1d(nn.Module):
    """Rotate the last (time) axis of a (B, C, L) tensor by a random
    amount in [-n, n] (1-D variant of phase shuffle).
    """

    def __init__(self, n=2):
        super(PhaseShuffle1d, self).__init__()
        self.n = n
        # Dedicated seeded RNG so the shift sequence is reproducible.
        self.random = random.Random(1)

    def forward(self, x, move=None):
        # x.size = (B, C, L); an explicit `move` overrides the RNG.
        if move is None:
            move = self.random.randint(-self.n, self.n)
        if move == 0:
            return x
        head = x[:, :, :move]
        tail = x[:, :, move:]
        return torch.cat([tail, head], dim=2)
331
+
332
class MFCC(nn.Module):
    """Convert a mel spectrogram to MFCCs via a precomputed DCT matrix."""

    def __init__(self, n_mfcc=40, n_mels=80):
        super(MFCC, self).__init__()
        self.n_mfcc = n_mfcc
        self.n_mels = n_mels
        self.norm = 'ortho'
        # DCT-II basis from torchaudio; registered as a buffer so it moves
        # with the module's device/dtype but is not a trainable parameter.
        dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
        self.register_buffer('dct_mat', dct_mat)

    def forward(self, mel_specgram):
        """mel_specgram: (n_mels, time) or (channel, n_mels, time) -> MFCCs
        with the same leading shape and n_mfcc coefficients.
        """
        # Temporarily add a batch/channel dim for unbatched input.
        if len(mel_specgram.shape) == 2:
            mel_specgram = mel_specgram.unsqueeze(0)
            unsqueezed = True
        else:
            unsqueezed = False
        # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
        # -> (channel, time, n_mfcc).transpose(...)
        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)

        # Undo the temporary batching.
        if unsqueezed:
            mfcc = mfcc.squeeze(0)
        return mfcc
chiluka/pretrained/ASR/models.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import TransformerEncoder
5
+ import torch.nn.functional as F
6
+ from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
+
8
class ASRCNN(nn.Module):
    """CNN-based ASR encoder: mel spectrogram -> MFCC -> dilated conv stack.

    Produces frame-level CTC logits; when ``text_input`` is provided, also
    runs the attention seq2seq decoder head (:class:`ASRS2S`).
    """

    def __init__(self,
                 input_dim=80,
                 hidden_dim=256,
                 n_token=35,
                 n_layers=6,
                 token_embedding_dim=256,

                 ):
        super().__init__()
        self.n_token = n_token
        # Number of stride-2 downsamplings applied to the time axis
        # (init_cnn halves it once).
        self.n_down = 1
        # MFCC() defaults to n_mfcc = 40 = input_dim // 2 channels.
        self.to_mfcc = MFCC()
        self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
        self.cnns = nn.Sequential(
            *[nn.Sequential(
                ConvBlock(hidden_dim),
                nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
            ) for n in range(n_layers)])
        self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
        self.ctc_linear = nn.Sequential(
            LinearNorm(hidden_dim//2, hidden_dim),
            nn.ReLU(),
            LinearNorm(hidden_dim, n_token))
        self.asr_s2s = ASRS2S(
            embedding_dim=token_embedding_dim,
            hidden_dim=hidden_dim//2,
            n_token=n_token)

    def forward(self, x, src_key_padding_mask=None, text_input=None):
        # x: mel spectrogram, assumed (B, n_mels, T) — TODO confirm caller.
        x = self.to_mfcc(x)
        x = self.init_cnn(x)
        x = self.cnns(x)
        x = self.projection(x)
        # -> (B, T', hidden//2) for the linear heads.
        x = x.transpose(1, 2)
        ctc_logit = self.ctc_linear(x)
        if text_input is not None:
            _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
            return ctc_logit, s2s_logit, s2s_attn
        else:
            return ctc_logit

    def get_feature(self, x):
        # Encoder features only (no logit heads); accepts (B, 1, n_mels, T).
        x = self.to_mfcc(x.squeeze(1))
        x = self.init_cnn(x)
        x = self.cnns(x)
        x = self.projection(x)
        return x

    def length_to_mask(self, lengths):
        # Boolean padding mask: True where index >= length for that item.
        mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
        mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
        return mask

    def get_future_mask(self, out_length, unmask_future_steps=0):
        """
        Args:
            out_length (int): returned mask shape is (out_length, out_length).
            unmask_future_steps (int): unmasking future step size.
        Return:
            mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
        """
        index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
        mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
        return mask
+ return mask
73
+
74
class ASRS2S(nn.Module):
    """Attention-based seq2seq decoder head for the ASR model.

    Decodes token embeddings with an LSTM cell attending over the encoder
    memory via location-sensitive attention; forward() is teacher-forced.
    Decoder state is held on ``self`` between :meth:`decode` steps.
    """

    def __init__(self,
                 embedding_dim=256,
                 hidden_dim=512,
                 n_location_filters=32,
                 location_kernel_size=63,
                 n_token=40):
        super(ASRS2S, self).__init__()
        self.embedding = nn.Embedding(n_token, embedding_dim)
        # Uniform init in +/- sqrt(6/hidden_dim).
        val_range = math.sqrt(6 / hidden_dim)
        self.embedding.weight.data.uniform_(-val_range, val_range)

        self.decoder_rnn_dim = hidden_dim
        self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
        self.attention_layer = Attention(
            self.decoder_rnn_dim,
            hidden_dim,
            hidden_dim,
            n_location_filters,
            location_kernel_size
        )
        self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
        self.project_to_hidden = nn.Sequential(
            LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
            nn.Tanh())
        # Special token ids — presumably fixed by the vocabulary; verify
        # against the text tokenizer used with this model.
        self.sos = 1
        self.eos = 2

    def initialize_decoder_states(self, memory, mask):
        """Reset per-utterance decoder state.

        memory.shape = (B, L, H) = (batch, max timestep, hidden dim).
        """
        B, L, H = memory.shape
        self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
        self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
        self.attention_weights = torch.zeros((B, L)).type_as(memory)
        self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
        self.attention_context = torch.zeros((B, H)).type_as(memory)
        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask
        # Token id substituted when randomly masking teacher-forcing inputs.
        self.unk_index = 3
        self.random_mask = 0.1

    def forward(self, memory, memory_mask, text_input):
        """Teacher-forced decoding.

        memory.shape = (B, L, H); memory_mask.shape = (B, L);
        text_input.shape = (B, T).
        Returns (hidden_outputs, logit_outputs, alignments), batch-first.
        """
        self.initialize_decoder_states(memory, memory_mask)
        # Randomly replace ~10% of input tokens with <unk> (regularisation).
        random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
        _text_input = text_input.clone()
        _text_input.masked_fill_(random_mask, self.unk_index)
        decoder_inputs = self.embedding(_text_input).transpose(0, 1)  # -> [T, B, channel]
        # Prepend the <sos> embedding as the first decoder input.
        start_embedding = self.embedding(
            torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
        decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)

        hidden_outputs, logit_outputs, alignments = [], [], []
        while len(hidden_outputs) < decoder_inputs.size(0):

            decoder_input = decoder_inputs[len(hidden_outputs)]
            hidden, logit, attention_weights = self.decode(decoder_input)
            hidden_outputs += [hidden]
            logit_outputs += [logit]
            alignments += [attention_weights]

        hidden_outputs, logit_outputs, alignments = \
            self.parse_decoder_outputs(
                hidden_outputs, logit_outputs, alignments)

        return hidden_outputs, logit_outputs, alignments

    def decode(self, decoder_input):
        """Run one decoder step; reads and updates state on ``self``."""
        cell_input = torch.cat((decoder_input, self.attention_context), -1)
        self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
            cell_input,
            (self.decoder_hidden, self.decoder_cell))

        # Location features: previous + cumulative attention weights.
        attention_weights_cat = torch.cat(
            (self.attention_weights.unsqueeze(1),
             self.attention_weights_cum.unsqueeze(1)), dim=1)

        self.attention_context, self.attention_weights = self.attention_layer(
            self.decoder_hidden,
            self.memory,
            self.processed_memory,
            attention_weights_cat,
            self.mask)

        self.attention_weights_cum += self.attention_weights

        hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
        hidden = self.project_to_hidden(hidden_and_context)

        # Dropout before the logit projection (active in training mode only).
        logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))

        return hidden, logit, self.attention_weights

    def parse_decoder_outputs(self, hidden, logit, alignments):
        """Stack per-step lists into batch-first tensors."""
        # -> [B, T_out + 1, max_time]
        alignments = torch.stack(alignments).transpose(0, 1)
        # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
        logit = torch.stack(logit).transpose(0, 1).contiguous()
        hidden = torch.stack(hidden).transpose(0, 1).contiguous()

        return hidden, logit, alignments
+ return hidden, logit, alignments
chiluka/pretrained/JDC/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
chiluka/pretrained/JDC/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes). View file
 
chiluka/pretrained/JDC/__pycache__/model.cpython-310.pyc ADDED
Binary file (4.78 kB). View file
 
chiluka/pretrained/JDC/bst.t7 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
chiluka/pretrained/JDC/model.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of model from:
3
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
4
+ Convolutional Recurrent Neural Networks" (2019)
5
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
6
+ """
7
+ import torch
8
+ from torch import nn
9
+
10
class JDCNet(nn.Module):
    """
    Joint Detection and Classification Network model for singing voice melody.

    Used downstream as a pre-trained pitch/F0 feature extractor. Module
    attribute names and their registration order must match the published
    checkpoint's state_dict keys, so do not rename or reorder them.
    """
    def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
        """
        Args:
            num_class: number of pitch classes predicted per frame.
            seq_len: nominal frame count (documents the shape comments below;
                forward() reads the actual length from the input instead).
            leaky_relu_slope: negative slope for every LeakyReLU in the net.
        """
        super().__init__()
        self.num_class = num_class

        # input = (b, 1, 31, 513), b = batch size
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False),  # out: (b, 64, 31, 513)
            nn.BatchNorm2d(num_features=64),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
        )

        # res blocks
        self.res_block1 = ResBlock(in_channels=64, out_channels=128)  # (b, 128, 31, 128)
        self.res_block2 = ResBlock(in_channels=128, out_channels=192)  # (b, 192, 31, 32)
        self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)

        # pool block
        self.pool_block = nn.Sequential(
            nn.BatchNorm2d(num_features=256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
            nn.Dropout(p=0.2),
        )

        # maxpool layers (for auxiliary network inputs)
        # NOTE(review): the detector branch (maxpool1-3, detector_conv,
        # bilstm_detector, detector) is never referenced by forward() in this
        # copy; it is kept so the pretrained checkpoint still loads.
        # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
        # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
        # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))

        # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
        self.detector_conv = nn.Sequential(
            nn.Conv2d(640, 256, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Dropout(p=0.2),
        )

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_classifier = nn.LSTM(
            input_size=512, hidden_size=256,
            batch_first=True, bidirectional=True)  # (b, 31, 512)

        # input: (b, 31, 512) - resized from (b, 256, 31, 2)
        self.bilstm_detector = nn.LSTM(
            input_size=512, hidden_size=256,
            batch_first=True, bidirectional=True)  # (b, 31, 512)

        # input: (b * 31, 512)
        self.classifier = nn.Linear(in_features=512, out_features=self.num_class)  # (b * 31, num_class)

        # input: (b * 31, 512)
        self.detector = nn.Linear(in_features=512, out_features=2)  # (b * 31, 2) - binary classifier

        # initialize weights (overwritten when a checkpoint is loaded)
        self.apply(self.init_weights)

    def get_feature_GAN(self, x):
        """Return the pre-pool activation map, used as a GAN feature.

        Runs conv + res blocks plus only the BN and LeakyReLU stages of the
        pool block, then swaps the last two axes back to the input layout.
        """
        seq_len = x.shape[-2]  # NOTE(review): computed but unused
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return poolblock_out.transpose(-1, -2)

    def get_feature(self, x):
        """Return the pooled feature map (BN + LeakyReLU + MaxPool, no dropout)."""
        seq_len = x.shape[-2]  # NOTE(review): computed but unused
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)
        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)

        return self.pool_block[2](poolblock_out)

    def forward(self, x):
        """
        Run the classifier path.

        Returns:
            A 3-tuple ``(pitch, GAN_feature, poolblock_out)``:
            - pitch: ``torch.abs`` of the per-frame class scores,
              shape (b, seq_len, num_class) before the squeeze()
            - GAN_feature: pre-pool activation map (last two axes swapped)
            - poolblock_out: pooled features fed to the BiLSTM classifier
        """
        ###############################
        # forward pass for classifier #
        ###############################
        seq_len = x.shape[-1]
        x = x.float().transpose(-1, -2)

        convblock_out = self.conv_block(x)

        resblock1_out = self.res_block1(convblock_out)
        resblock2_out = self.res_block2(resblock1_out)
        resblock3_out = self.res_block3(resblock2_out)


        poolblock_out = self.pool_block[0](resblock3_out)
        poolblock_out = self.pool_block[1](poolblock_out)
        GAN_feature = poolblock_out.transpose(-1, -2)
        poolblock_out = self.pool_block[2](poolblock_out)

        # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
        classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
        classifier_out, _ = self.bilstm_classifier(classifier_out)  # ignore the hidden states

        classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
        classifier_out = self.classifier(classifier_out)
        classifier_out = classifier_out.view((-1, seq_len, self.num_class))  # (b, 31, num_class)

        # classifier output consists of predicted pitch classes per frame;
        # torch.abs keeps the scores non-negative for downstream F0 use
        return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out

    @staticmethod
    def init_weights(m):
        """Per-layer-type weight initialization, applied via ``self.apply``."""
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Conv2d):
            nn.init.xavier_normal_(m.weight)
        elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
            for p in m.parameters():
                if p.data is None:
                    continue

                # orthogonal init for weight matrices, normal for bias vectors
                if len(p.shape) >= 2:
                    nn.init.orthogonal_(p.data)
                else:
                    nn.init.normal_(p.data)
156
+
157
+
158
class ResBlock(nn.Module):
    """Residual block that halves the last (frequency) axis via max-pooling.

    Layout follows Figure 1b of Kum et al.: BN -> LeakyReLU -> MaxPool(1, 2)
    ahead of the convolutions, with a 1x1 convolution on the shortcut path
    whenever the channel count changes.
    """

    def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
        super().__init__()
        # A channel-count change means the shortcut needs a 1x1 projection.
        self.downsample = in_channels != out_channels

        # Pre-activation stage; pooling is applied on the last axis only.
        self.pre_conv = nn.Sequential(
            nn.BatchNorm2d(num_features=in_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2)),
        )

        # Main path: two 3x3 convolutions with BN + LeakyReLU between them.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(leaky_relu_slope, inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
        )

        # 1x1 convolution to match channel dimensions on the shortcut.
        self.conv1by1 = None
        if self.downsample:
            self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)

    def forward(self, x):
        pooled = self.pre_conv(x)
        shortcut = self.conv1by1(pooled) if self.downsample else pooled
        return self.conv(pooled) + shortcut
chiluka/pretrained/PLBERT/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.75 kB). View file
 
chiluka/pretrained/PLBERT/config.yml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "Checkpoint"
2
+ mixed_precision: "fp16"
3
+ data_folder: "wikipedia_20220301.en.processed"
4
+ batch_size: 192
5
+ save_interval: 5000
6
+ log_interval: 10
7
+ num_process: 1 # number of GPUs
8
+ num_steps: 1000000
9
+
10
+ dataset_params:
11
+ tokenizer: "transfo-xl-wt103"
12
+ token_separator: " " # token used for phoneme separator (space)
13
+ token_mask: "M" # token used for phoneme mask (M)
14
+ word_separator: 3039 # token used for word separator (<formula>)
15
+ token_maps: "token_maps.pkl" # token map path
16
+
17
+ max_mel_length: 512 # max phoneme length
18
+
19
+ word_mask_prob: 0.15 # probability to mask the entire word
20
+ phoneme_mask_prob: 0.1 # probability to mask each phoneme
21
+ replace_prob: 0.2 # probablity to replace phonemes
22
+
23
+ model_params:
24
+ vocab_size: 178
25
+ hidden_size: 768
26
+ num_attention_heads: 12
27
+ intermediate_size: 2048
28
+ max_position_embeddings: 512
29
+ num_hidden_layers: 12
30
+ dropout: 0.1
chiluka/pretrained/PLBERT/step_1000000.t7 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
3
+ size 25185187
chiluka/pretrained/PLBERT/util.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from transformers import AlbertConfig, AlbertModel
5
+
6
class CustomAlbert(AlbertModel):
    """AlbertModel variant whose forward returns only the last hidden state."""

    def forward(self, *args, **kwargs):
        # Delegate to AlbertModel.forward, then unwrap the output object so
        # callers receive a plain tensor instead of a ModelOutput.
        result = super().forward(*args, **kwargs)
        return result.last_hidden_state
13
+
14
+
15
def load_plbert(log_dir):
    """Load the pre-trained PL-BERT (phoneme-level ALBERT) encoder.

    Reads ``config.yml`` from *log_dir*, builds an ALBERT model from its
    ``model_params`` section, then loads the newest ``step_<N>.t7``
    checkpoint found in the directory (highest step number wins).

    Args:
        log_dir: Directory containing ``config.yml`` and ``step_<N>.t7``
            checkpoint files.

    Returns:
        A ``CustomAlbert`` instance with the checkpoint weights loaded
        (non-strict, so extra/missing keys are tolerated).

    Raises:
        FileNotFoundError: If ``config.yml`` is missing.
        IndexError: If no ``step_*.t7`` checkpoint file exists in *log_dir*.
    """
    config_path = os.path.join(log_dir, "config.yml")
    # Context manager so the config file handle is closed promptly
    # (the original left it open).
    with open(config_path) as f:
        plbert_config = yaml.safe_load(f)

    albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
    bert = CustomAlbert(albert_base_configuration)

    # Pick the checkpoint with the largest step number (single directory scan).
    ckpts = [f for f in os.listdir(log_dir)
             if f.startswith("step_") and os.path.isfile(os.path.join(log_dir, f))]
    latest_step = sorted(int(f.split('_')[-1].split('.')[0]) for f in ckpts)[-1]

    checkpoint = torch.load(
        os.path.join(log_dir, "step_" + str(latest_step) + ".t7"),
        map_location='cpu')
    state_dict = checkpoint['net']

    # The checkpoint was saved from a DataParallel-wrapped trainer, so keys
    # carry a `module.` prefix (and encoder weights an extra `encoder.`).
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        if name.startswith('encoder.'):
            name = name[8:]  # remove `encoder.`
        new_state_dict[name] = v
    # Some transformers versions register position_ids as a buffer and some
    # do not; drop the key if present instead of crashing on KeyError.
    new_state_dict.pop("embeddings.position_ids", None)
    bert.load_state_dict(new_state_dict, strict=False)

    return bert
chiluka/text_utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Text processing utilities for phoneme tokenization."""
2
+
3
+ _pad = "$"
4
+ _punctuation = ';:,.!?¡¿—…"«»"" '
5
+ _letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
6
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
7
+
8
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
9
+
10
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
11
+
12
+
13
+ class TextCleaner:
14
+ """Converts phoneme strings to token IDs."""
15
+
16
+ def __init__(self):
17
+ self.word_index_dictionary = _symbol_to_id
18
+
19
+ def __call__(self, text):
20
+ indexes = []
21
+ for char in text:
22
+ if char in self.word_index_dictionary:
23
+ indexes.append(self.word_index_dictionary[char])
24
+ return indexes
chiluka/utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for Chiluka."""
2
+
3
+ import torch
4
+ from munch import Munch
5
+
6
+
7
def length_to_mask(lengths):
    """Convert a 1-D tensor of lengths into a boolean padding mask.

    Position ``j`` of row ``i`` is True when ``j`` is past the end of
    sequence ``i`` (i.e. padding), False where real content lives.
    """
    max_len = lengths.max()
    positions = torch.arange(max_len).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    return torch.gt(positions + 1, lengths.unsqueeze(1))
12
+
13
+
14
def recursive_munch(d):
    """Recursively convert nested dicts to Munch for dot-notation access.

    Lists are rebuilt element-by-element; any other value passes through
    unchanged.
    """
    if isinstance(d, list):
        return [recursive_munch(item) for item in d]
    if isinstance(d, dict):
        return Munch((key, recursive_munch(value)) for key, value in d.items())
    return d
examples/basic_synthesis.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Basic example of using Chiluka for TTS synthesis.
4
+
5
+ Usage:
6
+ python basic_synthesis.py --reference path/to/reference.wav --text "Hello world"
7
+ """
8
+
9
+ import argparse
10
+ import sys
11
+ import os
12
+
13
+ # Add parent directory to path if running from examples folder
14
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
+
16
+ from chiluka import Chiluka
17
+
18
+
19
def main():
    """CLI entry point: synthesize ``--text`` in the voice of ``--reference``."""
    arg_parser = argparse.ArgumentParser(description="Chiluka TTS Synthesis")
    arg_parser.add_argument("--reference", "-r", required=True, help="Path to reference audio file")
    arg_parser.add_argument("--text", "-t", default="Hello, this is Chiluka speaking!", help="Text to synthesize")
    arg_parser.add_argument("--language", "-l", default="en", help="Language code (en, te, hi, etc.)")
    arg_parser.add_argument("--output", "-o", default="output.wav", help="Output WAV file path")
    arg_parser.add_argument("--alpha", type=float, default=0.3, help="Acoustic style mixing (0-1)")
    arg_parser.add_argument("--beta", type=float, default=0.7, help="Prosodic style mixing (0-1)")
    arg_parser.add_argument("--steps", type=int, default=5, help="Diffusion steps")
    opts = arg_parser.parse_args()

    # Model initialization uses the weights bundled with the package.
    print("Initializing Chiluka TTS...")
    engine = Chiluka()

    print(f"Synthesizing: '{opts.text}'")
    audio = engine.synthesize(
        text=opts.text,
        reference_audio=opts.reference,
        language=opts.language,
        alpha=opts.alpha,
        beta=opts.beta,
        diffusion_steps=opts.steps,
    )

    # Write the result to disk.
    engine.save_wav(audio, opts.output)
    print(f"Done! Output saved to: {opts.output}")


if __name__ == "__main__":
    main()
examples/telugu_synthesis.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Telugu TTS synthesis example using Chiluka.
4
+
5
+ Usage:
6
+ python telugu_synthesis.py --reference path/to/telugu_reference.wav
7
+ """
8
+
9
+ import argparse
10
+ import sys
11
+ import os
12
+
13
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
+
15
+ from chiluka import Chiluka
16
+
17
+
18
def main():
    """Synthesize a fixed set of Telugu sample sentences with one reference voice."""
    cli = argparse.ArgumentParser(description="Chiluka Telugu TTS")
    cli.add_argument("--reference", "-r", required=True, help="Path to Telugu reference audio")
    cli.add_argument("--output", "-o", default="telugu_output.wav", help="Output file")
    opts = cli.parse_args()

    # Sample Telugu texts
    sample_texts = [
        "నమస్కారం, నేను చిలుక మాట్లాడుతున్నాను",
        "మహారాజా తమరిని మోసగించి నేను ఎక్కడికి పారిపోగలను",
        "మీకు ధన్యవాదాలు",
    ]

    print("Initializing Chiluka TTS...")
    engine = Chiluka()

    # Synthesize each sentence and write it to a numbered output file.
    for index, sentence in enumerate(sample_texts):
        print(f"\nSynthesizing ({index+1}/{len(sample_texts)}): {sentence}")
        audio = engine.synthesize(
            text=sentence,
            reference_audio=opts.reference,
            language="te",
            alpha=0.3,
            beta=0.7,
        )

        numbered_path = opts.output.replace(".wav", f"_{index+1}.wav")
        engine.save_wav(audio, numbered_path)

    print("\nDone!")


if __name__ == "__main__":
    main()
pyproject.toml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel", "setuptools-scm[toml]>=6.2"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "chiluka"
7
+ version = "0.1.0"
8
+ description = "Chiluka - A lightweight TTS inference package based on StyleTTS2"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ requires-python = ">=3.8"
12
+ authors = [
13
+ {name = "Your Name", email = "your.email@example.com"}
14
+ ]
15
+ keywords = ["tts", "text-to-speech", "speech-synthesis", "styletts2", "deep-learning"]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Developers",
19
+ "Intended Audience :: Science/Research",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Operating System :: OS Independent",
22
+ "Programming Language :: Python :: 3",
23
+ "Programming Language :: Python :: 3.8",
24
+ "Programming Language :: Python :: 3.9",
25
+ "Programming Language :: Python :: 3.10",
26
+ "Programming Language :: Python :: 3.11",
27
+ "Topic :: Multimedia :: Sound/Audio :: Speech",
28
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
29
+ ]
30
+ dependencies = [
31
+ "torch>=1.13.0",
32
+ "torchaudio>=0.13.0",
33
+ "transformers>=4.20.0",
34
+ "librosa>=0.9.0",
35
+ "phonemizer>=3.0.0",
36
+ "nltk>=3.7",
37
+ "PyYAML>=6.0",
38
+ "munch>=2.5.0",
39
+ "einops>=0.6.0",
40
+ "einops-exts>=0.0.4",
41
+ "numpy>=1.21.0",
42
+ "scipy>=1.7.0",
43
+ ]
44
+
45
+ [project.optional-dependencies]
46
+ playback = ["pyaudio>=0.2.11"]
47
+ dev = ["pytest>=7.0.0", "black>=22.0.0", "isort>=5.10.0"]
48
+
49
+ [project.urls]
50
+ Homepage = "https://github.com/yourusername/chiluka"
51
+ Documentation = "https://github.com/yourusername/chiluka#readme"
52
+ Repository = "https://github.com/yourusername/chiluka"
53
+ Issues = "https://github.com/yourusername/chiluka/issues"
54
+
55
+ [tool.setuptools.packages.find]
56
+ where = ["."]
57
+
58
+ [tool.black]
59
+ line-length = 120
60
+ target-version = ['py38', 'py39', 'py310', 'py311']
61
+
62
+ [tool.isort]
63
+ profile = "black"
64
+ line_length = 120
setup.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Setup script for Chiluka TTS package.

Kept alongside pyproject.toml for older tooling; the metadata here mirrors
the [project] table in pyproject.toml and the two must be kept in sync.
"""

from setuptools import setup, find_packages

# Long description shown on PyPI comes straight from the README.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="chiluka",
    version="0.1.0",  # NOTE(review): duplicated in pyproject.toml — keep in sync
    author="Your Name",  # NOTE(review): placeholder metadata — fill in before release
    author_email="your.email@example.com",
    description="Chiluka - A lightweight TTS inference package based on StyleTTS2",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/chiluka",  # NOTE(review): placeholder URL
    packages=find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Topic :: Multimedia :: Sound/Audio :: Speech",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    python_requires=">=3.8",
    # Runtime dependencies; mirrors [project].dependencies in pyproject.toml.
    install_requires=[
        "torch>=1.13.0",
        "torchaudio>=0.13.0",
        "transformers>=4.20.0",
        "librosa>=0.9.0",
        "phonemizer>=3.0.0",
        "nltk>=3.7",
        "PyYAML>=6.0",
        "munch>=2.5.0",
        "einops>=0.6.0",
        "einops-exts>=0.0.4",
        "numpy>=1.21.0",
        "scipy>=1.7.0",
    ],
    extras_require={
        "playback": ["pyaudio>=0.2.11"],
        "dev": [
            "pytest>=7.0.0",
            "black>=22.0.0",
            "isort>=5.10.0",
        ],
    },
    entry_points={
        "console_scripts": [
            # NOTE(review): this points at chiluka.cli:main, but no
            # chiluka/cli.py is present in this commit — confirm the CLI
            # module exists before publishing, or installs will expose a
            # broken `chiluka` command.
            "chiluka=chiluka.cli:main",
        ],
    },
)