| | import torch |
| | import torch.nn as nn |
| | import torchaudio |
| | from transformers import AutoProcessor, MusicgenForConditionalGeneration |
| | import os |
| | import matplotlib.pyplot as plt |
| |
|
| | |
# --- Configuration --------------------------------------------------------
# Base MusicGen checkpoint pulled from the Hugging Face hub.
pretrained_model_name = "facebook/musicgen-medium"
# Fine-tuned adapter checkpoint; only read when use_finetuned_model is True.
model_save_path = "./ModelsFinetuned/Hindustani_Adapter/ISMIR/MusicgenMedium_with_adapters_EncoderDecoder_1024_Preceptron.pt"
# Output locations for the generated audio and its waveform plot.
output_audio_path = "./GeneratedAudios/New/HC/1.wav"
AudioWaveform_graph_path = "./GeneratedGraphs/New/HC/1.jpeg"
dropout_prob = 0.0             # adapter dropout probability; 0.0 disables dropout
sample_rate = 32000            # MusicGen's native output sample rate (Hz)
adapter_bottleneck_dim = 1024  # width of the adapter bottleneck projection
max_new_tokens = 1024          # generation length passed to musicgen.generate
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Toggle between the plain pretrained model (False) and the
# adapter-augmented fine-tuned model (True); see load_model below.
use_finetuned_model = False
| |
|
| |
|
| | |
class Adapter(nn.Module):
    """Residual bottleneck adapter over the time axis of a mono waveform.

    Expects input of shape (batch, 1, seq_len). Down-projects the time
    dimension to ``bottleneck_channels``, applies GELU, projects back up,
    adds a residual connection, and returns the result duplicated across
    two channels, i.e. shape (batch, 2, seq_len).
    """

    def __init__(self, bottleneck_channels=256, input_channels=2, seq_len=32000,
                 dropout_p=None):
        """
        Args:
            bottleneck_channels: width of the bottleneck projection.
            input_channels: unused; kept for interface compatibility
                (forward squeezes dim 1 and assumes a single channel).
            seq_len: number of time samples per example; fixes the Linear
                layer sizes, so inputs must have exactly this length.
            dropout_p: dropout probability. When None (the default), falls
                back to the module-level ``dropout_prob`` global, which is
                the original behavior.
        """
        super(Adapter, self).__init__()
        # Fall back to the module-level setting so existing callers that
        # rely on the global keep their behavior.
        p = dropout_prob if dropout_p is None else dropout_p

        self.adapter_down = nn.Linear(seq_len, bottleneck_channels)
        self.activation = nn.GELU()
        self.adapter_up = nn.Linear(bottleneck_channels, seq_len)
        self.dropout = nn.Dropout(p=p)

    def forward(self, residual):
        """Apply the adapter; input (batch, 1, seq_len) -> (batch, 2, seq_len)."""
        x = self.adapter_down(residual.squeeze(1))   # (batch, seq_len) -> bottleneck
        x = self.activation(x)
        x = self.adapter_up(x)                       # back to (batch, seq_len)
        x = self.dropout(x + residual.squeeze(1))    # residual connection
        # Duplicate the mono result across two channels (expand is a view,
        # both channels share storage).
        return x.unsqueeze(1).expand(-1, 2, -1)
| |
|
# NOTE: The triple-quoted block below is a *disabled* alternative adapter
# (a convolutional ConvAdapter plus a ResidualBlock), kept for reference.
# As a bare string literal it has no runtime effect. If re-enabled, note
# that its __init__ calls super(ConvAdapter, self).__init__() while the
# class is named Adapter — that name would need fixing first.
"""

class Adapter(nn.Module):
    def __init__(self, bottleneck_channels=32, input_channels=2, seq_len=32000):
        super(ConvAdapter, self).__init__()

        # Down-projection: Reduce dimensionality
        self.adapter_down = nn.Sequential(
            nn.Conv1d(
                in_channels=input_channels,
                out_channels=bottleneck_channels,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU()
        )

        # Bottleneck: Deeper feature extraction with residual connections
        self.bottleneck = nn.Sequential(
            ResidualBlock(bottleneck_channels, bottleneck_channels),
            nn.Conv1d(
                in_channels=bottleneck_channels,
                out_channels=bottleneck_channels,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU()
        )

        # Up-projection: Restore original dimensionality
        self.adapter_up = nn.Sequential(
            nn.Conv1d(
                in_channels=bottleneck_channels,
                out_channels=input_channels,
                kernel_size=3,
                stride=1,
                padding=1
            ),
            nn.BatchNorm1d(input_channels)
        )

        # Dropout for regularization
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, residual):
        # Apply down-projection
        x = self.adapter_down(residual)

        # Apply bottleneck processing
        x = self.bottleneck(x)

        # Apply up-projection
        x = self.adapter_up(x)

        # Add residual connection and dropout
        x = self.dropout(x + residual)

        return x


class ResidualBlock(nn.Module):
    # A simple residual block for feature extraction in the bottleneck layer.

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.activation = nn.GELU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, stride, padding)
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        # Residual connection
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x + residual
"""
| |
|
| | |
class MusicGenWithAdapters(nn.Module):
    """Wraps a MusicGen model and post-processes its generated audio
    through a waveform-level Adapter.

    NOTE(review): the ``processor`` argument is accepted but never stored
    or used by this class.
    """

    def __init__(self, musicgen_model, processor, adapter_bottleneck_dim=32, device='cpu'):
        super(MusicGenWithAdapters, self).__init__()
        self.musicgen = musicgen_model
        # The adapter operates on mono waveforms of exactly 32000 samples.
        self.adapter = Adapter(bottleneck_channels=adapter_bottleneck_dim, input_channels=1, seq_len=32000).to(device)

    def forward(self, audio_text):
        # Generate audio from the processed text inputs; uses the
        # module-level max_new_tokens.
        encoder_output = self.musicgen.generate(**audio_text, max_new_tokens=max_new_tokens)
        # Resampling runs on CPU, then the result moves to the adapter's device.
        encoder_output = encoder_output.to('cpu')
        # NOTE(review): orig_freq is set to the *sequence length*, so this
        # Resample forces the waveform to ~32000 samples rather than doing
        # a true sample-rate conversion — presumably to match the Adapter's
        # fixed seq_len. Confirm this is intentional.
        encoder_output = torchaudio.transforms.Resample(orig_freq=encoder_output.size(2), new_freq=32000)(encoder_output)
        encoder_output = encoder_output.to(self.adapter.adapter_down.weight.device)
        adapted = self.adapter(encoder_output)
        return adapted
| |
|
| | |
def load_model(use_finetuned_model, model_save_path, device):
    """Load a MusicGen model, put it in eval mode, and report its size.

    Returns the plain pretrained MusicgenForConditionalGeneration when
    ``use_finetuned_model`` is False, otherwise a MusicGenWithAdapters
    wrapper restored from the checkpoint at ``model_save_path``.
    """
    if not use_finetuned_model:
        # Plain pretrained model, no adapters attached.
        base_model = MusicgenForConditionalGeneration.from_pretrained(pretrained_model_name).to(device)
        base_model.eval()
        total_params = sum(p.numel() for p in base_model.parameters())
        print(f"Total number of parameters in the Original model: {total_params}")
        return base_model

    # Fine-tuned path: rebuild the wrapper, then restore both state dicts.
    processor = AutoProcessor.from_pretrained(pretrained_model_name)
    base_model = MusicgenForConditionalGeneration.from_pretrained(pretrained_model_name).to(device)
    wrapped = MusicGenWithAdapters(
        base_model,
        processor,
        adapter_bottleneck_dim=adapter_bottleneck_dim,
        device=device,
    ).to(device)

    checkpoint = torch.load(model_save_path, map_location=device)
    wrapped.musicgen.load_state_dict(checkpoint['musicgen_state_dict'])
    wrapped.adapter.load_state_dict(checkpoint['adapter_state_dict'])

    wrapped.eval()
    total_params = sum(p.numel() for p in wrapped.parameters())
    print(f"Total number of parameters in the fine-tuned model: {total_params}")
    return wrapped
| |
|
| | |
def generate_audio(model, text_prompt, sample_rate=32000):
    """Generate a waveform tensor from a text prompt.

    Args:
        model: either a MusicGenWithAdapters wrapper or a plain
            MusicgenForConditionalGeneration.
        text_prompt: text conditioning for generation.
        sample_rate: desired output sample rate; MusicGen produces 32 kHz,
            any other value triggers a resample.

    Returns:
        A CPU tensor of shape (channels, time).
    """
    # A fresh processor is built on every call (could be hoisted by callers).
    processor = AutoProcessor.from_pretrained(pretrained_model_name)

    # Tokenize the prompt; moved to the module-level `device`.
    input_data = processor(text=[text_prompt], return_tensors="pt").to(device)

    # Unwrap to reach the underlying MusicGen model.
    # NOTE(review): this bypasses the fine-tuned Adapter entirely —
    # MusicGenWithAdapters.forward would apply it after generation; confirm
    # whether the adapter is meant to be applied at inference time.
    if isinstance(model, MusicGenWithAdapters):
        musicgen = model.musicgen
    else:
        musicgen = model

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated_output = musicgen.generate(**input_data, max_new_tokens=max_new_tokens)

    # Drop the batch dimension and move to CPU: (channels, time).
    waveform = generated_output.squeeze(0).cpu()

    # MusicGen outputs 32 kHz audio; resample only if another rate was asked for.
    if sample_rate != 32000:
        resampler = torchaudio.transforms.Resample(orig_freq=32000, new_freq=sample_rate)
        waveform = resampler(waveform)

    return waveform
| |
|
| | |
if __name__ == "__main__":
    # Prompt the user for the text conditioning.
    text_prompt = input("Enter a text prompt for music generation: ")

    # Load either the pretrained or the fine-tuned model
    # (controlled by the module-level use_finetuned_model flag).
    model = load_model(use_finetuned_model, model_save_path, device)

    waveform = generate_audio(
        model,
        text_prompt,
        sample_rate=sample_rate
    )

    # Ensure the output directories exist before writing — torchaudio.save
    # and plt.savefig both fail if a parent directory is missing.
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    os.makedirs(os.path.dirname(AudioWaveform_graph_path), exist_ok=True)

    torchaudio.save(output_audio_path, waveform, sample_rate)
    print(f"Generated audio saved at {output_audio_path}")

    # Plot the waveform; transpose to (time, channels) so each channel
    # becomes one line on the plot.
    plt.figure(figsize=(12, 4))
    plt.plot(waveform.t().numpy())
    plt.savefig(AudioWaveform_graph_path)
    plt.show()
    print(f"Waveform graph saved at {AudioWaveform_graph_path}")
| |
|