File size: 8,521 Bytes
7970e88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import torch
import torch.nn as nn
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import os
import matplotlib.pyplot as plt

# Paths and configurations (module-level constants used by every function below)
pretrained_model_name = "facebook/musicgen-medium"  # Pre-trained MusicGen model name on the HF hub
model_save_path = "./ModelsFinetuned/Hindustani_Adapter/ISMIR/MusicgenMedium_with_adapters_EncoderDecoder_1024_Preceptron.pt"  # Path to the fine-tuned checkpoint (musicgen + adapter state dicts)
output_audio_path = "./GeneratedAudios/New/HC/1.wav"  # Path to save the generated audio
AudioWaveform_graph_path = "./GeneratedGraphs/New/HC/1.jpeg"  # Path to save the plot of generated audio
dropout_prob = 0.0  # Dropout probability in Adapter layers (0.0 for inference)
sample_rate = 32000  # Desired sample rate for the output audio (MusicGen natively emits 32 kHz)
adapter_bottleneck_dim = 1024  # Must match the bottleneck dimension used during training
max_new_tokens = 1024 # Controls length of generated music; 512 tokens ~= 10 sec
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Configuration to choose between pre-trained and fine-tuned model
use_finetuned_model = False  # Set to True to use the fine-tuned model, False to use pre-trained model


# MultiLayer Perceptron
class Adapter(nn.Module):
    """Residual MLP adapter acting on the time axis of a mono waveform.

    The (batch, 1, seq_len) input is squeezed to (batch, seq_len), projected
    down to a bottleneck, projected back up, combined with the residual,
    passed through dropout, and finally duplicated to two channels so the
    output matches a stereo (batch, 2, seq_len) layout.
    """

    def __init__(self, bottleneck_channels=256, input_channels=2, seq_len=32000,
                 dropout_p=None):  # input_channels=2
        # `input_channels` is kept only for interface compatibility with the
        # training code: forward() always squeezes dim 1 and emits 2 channels.
        # `dropout_p` generalizes the previously hard-wired module-level
        # `dropout_prob`; passing None preserves the original behavior.
        super(Adapter, self).__init__()

        self.adapter_down = nn.Linear(seq_len, bottleneck_channels)
        self.activation = nn.GELU()
        self.adapter_up = nn.Linear(bottleneck_channels, seq_len)
        self.dropout = nn.Dropout(p=dropout_prob if dropout_p is None else dropout_p)

    def forward(self, residual):
        """Apply the adapter.

        residual: tensor of shape (batch, 1, seq_len).
        Returns a tensor of shape (batch, 2, seq_len) — the adapted mono
        signal expanded (not copied) to two identical channels.
        """
        x = self.adapter_down(residual.squeeze(1))
        x = self.activation(x)
        x = self.adapter_up(x)
        # Residual connection is added BEFORE dropout (matches training).
        x = self.dropout(x + residual.squeeze(1))

        return x.unsqueeze(1).expand(-1, 2, -1)  # Expanding to 2 channels

"""

class Adapter(nn.Module):
    def __init__(self, bottleneck_channels=32, input_channels=2, seq_len=32000):
        super(Adapter, self).__init__()
        
        # Down-projection: Reduce dimensionality
        self.adapter_down = nn.Sequential(
            nn.Conv1d(
                in_channels=input_channels, 
                out_channels=bottleneck_channels, 
                kernel_size=3, 
                stride=1, 
                padding=1
            ),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU()
        )
        
        # Bottleneck: Deeper feature extraction with residual connections
        self.bottleneck = nn.Sequential(
            ResidualBlock(bottleneck_channels, bottleneck_channels),
            nn.Conv1d(
                in_channels=bottleneck_channels, 
                out_channels=bottleneck_channels, 
                kernel_size=3, 
                stride=1, 
                padding=1
            ),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU()
        )
        
        # Up-projection: Restore original dimensionality
        self.adapter_up = nn.Sequential(
            nn.Conv1d(
                in_channels=bottleneck_channels, 
                out_channels=input_channels, 
                kernel_size=3, 
                stride=1, 
                padding=1
            ),
            nn.BatchNorm1d(input_channels)
        )
        
        # Dropout for regularization
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, residual):
        # Apply down-projection
        x = self.adapter_down(residual)
        
        # Apply bottleneck processing
        x = self.bottleneck(x)
        
        # Apply up-projection
        x = self.adapter_up(x)
        
        # Add residual connection and dropout
        x = self.dropout(x + residual)
        
        return x


class ResidualBlock(nn.Module):
    # A simple residual block for feature extraction in the bottleneck layer.

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.activation = nn.GELU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, stride, padding)
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        # Residual connection
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x + residual
"""

# MusicGen Model with Adapter (same as in training)
class MusicGenWithAdapters(nn.Module):
    """Wraps a pretrained MusicGen model with a waveform-level Adapter.

    forward() first generates audio with the wrapped MusicGen model, then
    passes the generated waveform through the Adapter.
    """

    def __init__(self, musicgen_model, processor, adapter_bottleneck_dim=32, device='cpu'):
        # NOTE(review): `processor` is accepted but never stored or used here;
        # it appears to exist only for signature compatibility with training.
        super(MusicGenWithAdapters, self).__init__()
        self.musicgen = musicgen_model
        self.adapter = Adapter(bottleneck_channels=adapter_bottleneck_dim, input_channels=1, seq_len=32000).to(device)

    def forward(self, audio_text):
        # `audio_text` is expanded as keyword args for generate(); presumably a
        # processor output (input_ids / attention_mask) — confirm against caller.
        encoder_output = self.musicgen.generate(**audio_text, max_new_tokens=max_new_tokens)
        encoder_output = encoder_output.to('cpu')
        # NOTE(review): orig_freq is set to the waveform LENGTH (size(2)), not a
        # sample rate — this effectively forces every clip to ~32000 samples so
        # it matches the Adapter's fixed seq_len. If a true sample-rate
        # conversion was intended, orig_freq should be the model's output rate;
        # confirm against the training code.
        encoder_output = torchaudio.transforms.Resample(orig_freq=encoder_output.size(2), new_freq=32000)(encoder_output)
        # Move the waveform to whichever device holds the adapter's weights.
        encoder_output = encoder_output.to(self.adapter.adapter_down.weight.device)
        adapted = self.adapter(encoder_output)
        return adapted

# Function to load the model based on the configuration
def load_model(use_finetuned_model, model_save_path, device):
    """Load and return the inference model in eval mode.

    Returns the plain pretrained MusicGen model when `use_finetuned_model`
    is False, otherwise the adapter-augmented wrapper with both state dicts
    restored from `model_save_path`. Prints the parameter count either way.
    """
    if not use_finetuned_model:
        # Plain pretrained MusicGen, no adapters.
        base_model = MusicgenForConditionalGeneration.from_pretrained(pretrained_model_name).to(device)
        base_model.eval()
        n_params = sum(p.numel() for p in base_model.parameters())
        print(f"Total number of parameters in the Original model: {n_params}")
        return base_model

    # Fine-tuned path: rebuild the wrapper, then restore both state dicts.
    processor = AutoProcessor.from_pretrained(pretrained_model_name)
    base_model = MusicgenForConditionalGeneration.from_pretrained(pretrained_model_name).to(device)
    wrapped = MusicGenWithAdapters(base_model, processor, adapter_bottleneck_dim=adapter_bottleneck_dim, device=device).to(device)

    checkpoint = torch.load(model_save_path, map_location=device)
    wrapped.musicgen.load_state_dict(checkpoint['musicgen_state_dict'])
    wrapped.adapter.load_state_dict(checkpoint['adapter_state_dict'])

    wrapped.eval()
    n_params = sum(p.numel() for p in wrapped.parameters())
    print(f"Total number of parameters in the fine-tuned model: {n_params}")
    return wrapped

# Function to generate audio from a text prompt
def generate_audio(model, text_prompt, sample_rate=32000):
    """Generate a waveform for `text_prompt` with the given model.

    Accepts either a plain MusicGen model or a MusicGenWithAdapters wrapper
    (only the underlying MusicGen model is used for generation). Returns a
    CPU waveform tensor, resampled when `sample_rate` differs from the
    native 32 kHz output.
    """
    processor = AutoProcessor.from_pretrained(pretrained_model_name)

    # Tokenize the prompt and move the tensors to the inference device.
    input_data = processor(text=[text_prompt], return_tensors="pt").to(device)

    # Unwrap the underlying MusicGen model when given the adapter wrapper.
    musicgen = model.musicgen if isinstance(model, MusicGenWithAdapters) else model

    with torch.no_grad():
        generated_output = musicgen.generate(**input_data, max_new_tokens=max_new_tokens)

    waveform = generated_output.squeeze(0).cpu()

    # MusicGen emits 32 kHz audio; convert only when another rate was requested.
    if sample_rate != 32000:
        waveform = torchaudio.transforms.Resample(orig_freq=32000, new_freq=sample_rate)(waveform)

    return waveform

# Main inference code
if __name__ == "__main__":
    # Get text prompt from the user
    text_prompt = input("Enter a text prompt for music generation: ")

    # Load the appropriate model (pre-trained or fine-tuned) based on the setting
    model = load_model(use_finetuned_model, model_save_path, device)

    # Generate audio
    waveform = generate_audio(
        model,
        text_prompt,
        sample_rate=sample_rate
    )

    # Ensure the output directories exist — torchaudio.save and plt.savefig
    # both fail when the parent directory is missing.
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    os.makedirs(os.path.dirname(AudioWaveform_graph_path), exist_ok=True)

    # Save the generated audio
    torchaudio.save(output_audio_path, waveform, sample_rate)
    print(f"Generated audio saved at {output_audio_path}")

    # Optional: Visualize the waveform
    plt.figure(figsize=(12, 4))
    plt.plot(waveform.t().numpy())
    plt.savefig(AudioWaveform_graph_path)
    plt.show()
    print(f"Waveform graph saved at {AudioWaveform_graph_path}")