thecodeworm commited on
Commit
5ebf545
·
verified ·
1 Parent(s): 4756c56

Create enhancement_model/model.py

Browse files
Files changed (1) hide show
  1. enhancement_model/model.py +188 -0
enhancement_model/model.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ U-Net Autoencoder for Speech Enhancement
3
+ Processes mel-spectrograms to remove noise
4
+
5
+ Architecture: Encoder-Decoder with skip connections
6
+ Input: Noisy mel-spectrogram (128 x T)
7
+ Output: Clean mel-spectrogram (128 x T)
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ class DoubleConv(nn.Module):
16
+ """
17
+ Double Convolution block: Conv -> BatchNorm -> ReLU -> Conv -> BatchNorm -> ReLU
18
+ Used as basic building block in U-Net
19
+ """
20
+ def __init__(self, in_channels, out_channels):
21
+ super().__init__()
22
+ self.double_conv = nn.Sequential(
23
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
24
+ nn.BatchNorm2d(out_channels),
25
+ nn.ReLU(inplace=True),
26
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
27
+ nn.BatchNorm2d(out_channels),
28
+ nn.ReLU(inplace=True)
29
+ )
30
+
31
+ def forward(self, x):
32
+ return self.double_conv(x)
33
+
34
+
35
+ class Down(nn.Module):
36
+ """
37
+ Downsampling block: MaxPool -> DoubleConv
38
+ Reduces spatial dimensions, increases channels
39
+ """
40
+ def __init__(self, in_channels, out_channels):
41
+ super().__init__()
42
+ self.maxpool_conv = nn.Sequential(
43
+ nn.MaxPool2d(2),
44
+ DoubleConv(in_channels, out_channels)
45
+ )
46
+
47
+ def forward(self, x):
48
+ return self.maxpool_conv(x)
49
+
50
+
51
+ class Up(nn.Module):
52
+ """
53
+ Upsampling block: Upsample -> Concat with skip connection -> DoubleConv
54
+ Increases spatial dimensions, decreases channels
55
+ """
56
+ def __init__(self, in_channels, out_channels):
57
+ super().__init__()
58
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
59
+ self.conv = DoubleConv(in_channels, out_channels)
60
+
61
+ def forward(self, x1, x2):
62
+ """
63
+ Args:
64
+ x1: Feature map from decoder path
65
+ x2: Feature map from encoder path (skip connection)
66
+ """
67
+ x1 = self.up(x1)
68
+
69
+ # Handle size mismatch due to padding
70
+ diffY = x2.size()[2] - x1.size()[2]
71
+ diffX = x2.size()[3] - x1.size()[3]
72
+
73
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
74
+ diffY // 2, diffY - diffY // 2])
75
+
76
+ # Concatenate skip connection
77
+ x = torch.cat([x2, x1], dim=1)
78
+ return self.conv(x)
79
+
80
+
81
+ class UNetAudioEnhancer(nn.Module):
82
+ """
83
+ U-Net model for audio enhancement
84
+
85
+ Architecture:
86
+ Encoder: 4 downsampling stages (64 -> 128 -> 256 -> 512)
87
+ Bottleneck: 1024 channels
88
+ Decoder: 4 upsampling stages (512 -> 256 -> 128 -> 64)
89
+ Output: 1 channel (clean spectrogram)
90
+
91
+ Args:
92
+ in_channels: Number of input channels (1 for single spectrogram)
93
+ out_channels: Number of output channels (1 for single spectrogram)
94
+ """
95
+ def __init__(self, in_channels=1, out_channels=1):
96
+ super().__init__()
97
+
98
+ # Initial convolution
99
+ self.inc = DoubleConv(in_channels, 64)
100
+
101
+ # Encoder (downsampling path)
102
+ self.down1 = Down(64, 128)
103
+ self.down2 = Down(128, 256)
104
+ self.down3 = Down(256, 512)
105
+ self.down4 = Down(512, 1024)
106
+
107
+ # Decoder (upsampling path)
108
+ self.up1 = Up(1024, 512)
109
+ self.up2 = Up(512, 256)
110
+ self.up3 = Up(256, 128)
111
+ self.up4 = Up(128, 64)
112
+
113
+ # Output convolution
114
+ self.outc = nn.Conv2d(64, out_channels, kernel_size=1)
115
+
116
+ def forward(self, x):
117
+ """
118
+ Forward pass through U-Net
119
+
120
+ Args:
121
+ x: Input tensor (batch_size, 1, height, width)
122
+ For mel-spectrograms: (B, 1, 128, T)
123
+
124
+ Returns:
125
+ Output tensor (batch_size, 1, height, width)
126
+ """
127
+ # Encoder path (save features for skip connections)
128
+ x1 = self.inc(x) # 64 channels
129
+ x2 = self.down1(x1) # 128 channels
130
+ x3 = self.down2(x2) # 256 channels
131
+ x4 = self.down3(x3) # 512 channels
132
+ x5 = self.down4(x4) # 1024 channels (bottleneck)
133
+
134
+ # Decoder path (with skip connections)
135
+ x = self.up1(x5, x4) # 512 channels
136
+ x = self.up2(x, x3) # 256 channels
137
+ x = self.up3(x, x2) # 128 channels
138
+ x = self.up4(x, x1) # 64 channels
139
+
140
+ # Output
141
+ x = self.outc(x) # 1 channel
142
+ return x
143
+
144
+ def count_parameters(self):
145
+ """Count trainable parameters"""
146
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
147
+
148
+
149
+ def test_model():
150
+ """
151
+ Test the model with dummy input
152
+ Verifies input/output dimensions
153
+ """
154
+ print("="*70)
155
+ print("Testing U-Net Audio Enhancer Model")
156
+ print("="*70)
157
+
158
+ # Create model
159
+ model = UNetAudioEnhancer(in_channels=1, out_channels=1)
160
+
161
+ # Print model info
162
+ print(f"\nModel Parameters: {model.count_parameters():,}")
163
+ print(f" (~{model.count_parameters() / 1e6:.2f}M parameters)")
164
+
165
+ # Test with dummy input
166
+ # Mel-spectrogram size: (batch, channels, mels, time)
167
+ # Time frames = (3 seconds * 16000 Hz) / 256 hop_length = 187.5 ≈ 188
168
+ batch_size = 4
169
+ mel_bins = 128
170
+ time_frames = 188
171
+
172
+ dummy_input = torch.randn(batch_size, 1, mel_bins, time_frames)
173
+ print(f"\n🔍 Input shape: {dummy_input.shape}")
174
+
175
+ # Forward pass
176
+ with torch.no_grad():
177
+ output = model(dummy_input)
178
+
179
+ print(f"Output shape: {output.shape}")
180
+
181
+ # Verify shapes match
182
+ assert output.shape == dummy_input.shape, "Output shape mismatch!"
183
+ print("\nModel test passed!")
184
+ print("="*70)
185
+
186
+
187
+ if __name__ == "__main__":
188
+ test_model()