import torch
import torch.nn as nn
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import os
import matplotlib.pyplot as plt

# Paths and configurations
pretrained_model_name = "facebook/musicgen-medium"  # Pre-trained MusicGen model name
model_save_path = "./ModelsFinetuned/Hindustani_Adapter/ISMIR/MusicgenMedium_with_adapters_EncoderDecoder_1024_Preceptron.pt"  # Path to the fine-tuned checkpoint
output_audio_path = "./GeneratedAudios/New/HC/1.wav"  # Path to save the generated audio
AudioWaveform_graph_path = "./GeneratedGraphs/New/HC/1.jpeg"  # Path to save the plot of the generated waveform
dropout_prob = 0.0  # Dropout probability in the adapter layers
sample_rate = 32000  # Desired sample rate for the output audio
adapter_bottleneck_dim = 1024  # Must match the bottleneck dimension used during training
max_new_tokens = 1024  # Controls the length of the generated piece (512 tokens ~= 10 s)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Configuration to choose between the pre-trained and the fine-tuned model
use_finetuned_model = False  # Set to True to use the fine-tuned model, False to use the pre-trained model


# Multilayer perceptron adapter
class Adapter(nn.Module):
    def __init__(self, bottleneck_channels=256, input_channels=2, seq_len=32000):
        super(Adapter, self).__init__()
        self.adapter_down = nn.Linear(seq_len, bottleneck_channels)
        self.activation = nn.GELU()
        self.adapter_up = nn.Linear(bottleneck_channels, seq_len)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, residual):
        x = self.adapter_down(residual.squeeze(1))
        x = self.activation(x)
        x = self.adapter_up(x)
        x = self.dropout(x + residual.squeeze(1))
        return x.unsqueeze(1).expand(-1, 2, -1)  # Expand to 2 channels


"""
# Alternative convolutional adapter (kept for reference, not used at inference)
class Adapter(nn.Module):
    def __init__(self, bottleneck_channels=32, input_channels=2, seq_len=32000):
        super(Adapter, self).__init__()
        # Down-projection: reduce dimensionality
        self.adapter_down = nn.Sequential(
            nn.Conv1d(in_channels=input_channels, out_channels=bottleneck_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU()
        )
        # Bottleneck: deeper feature extraction with a residual connection
        self.bottleneck = nn.Sequential(
            ResidualBlock(bottleneck_channels, bottleneck_channels),
            nn.Conv1d(in_channels=bottleneck_channels, out_channels=bottleneck_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU()
        )
        # Up-projection: restore the original dimensionality
        self.adapter_up = nn.Sequential(
            nn.Conv1d(in_channels=bottleneck_channels, out_channels=input_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(input_channels)
        )
        # Dropout for regularization
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, residual):
        # Down-projection
        x = self.adapter_down(residual)
        # Bottleneck processing
        x = self.bottleneck(x)
        # Up-projection
        x = self.adapter_up(x)
        # Residual connection and dropout
        x = self.dropout(x + residual)
        return x


class ResidualBlock(nn.Module):
    # A simple residual block for feature extraction in the bottleneck layer.
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.activation = nn.GELU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, stride, padding)
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        # Residual connection
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x + residual
"""
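
# A minimal shape sanity check for the MLP Adapter above (illustrative sketch, not
# part of the original pipeline): feed a dummy mono clip of shape (batch, 1, seq_len)
# and confirm that the forward pass returns a 2-channel tensor of the same length.
#
# adapter = Adapter(bottleneck_channels=adapter_bottleneck_dim, input_channels=1, seq_len=32000)
# dummy = torch.randn(4, 1, 32000)      # (batch, channels=1, samples)
# out = adapter(dummy)
# assert out.shape == (4, 2, 32000)     # residual MLP output expanded to 2 channels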

# MusicGen model with adapter (same architecture as in training)
class MusicGenWithAdapters(nn.Module):
    def __init__(self, musicgen_model, processor, adapter_bottleneck_dim=32, device='cpu'):
        super(MusicGenWithAdapters, self).__init__()
        self.musicgen = musicgen_model
        self.adapter = Adapter(bottleneck_channels=adapter_bottleneck_dim, input_channels=1,
                               seq_len=32000).to(device)

    def forward(self, audio_text):
        # Generate audio with MusicGen, then pass the waveform through the adapter
        encoder_output = self.musicgen.generate(**audio_text, max_new_tokens=max_new_tokens)
        encoder_output = encoder_output.to('cpu')
        encoder_output = torchaudio.transforms.Resample(orig_freq=encoder_output.size(2), new_freq=32000)(encoder_output)
        encoder_output = encoder_output.to(self.adapter.adapter_down.weight.device)
        adapted = self.adapter(encoder_output)
        return adapted


# Load the appropriate model based on the configuration
def load_model(use_finetuned_model, model_save_path, device):
    if use_finetuned_model:
        # Load the fine-tuned model (MusicGen + adapters)
        processor = AutoProcessor.from_pretrained(pretrained_model_name)
        musicgen_model = MusicgenForConditionalGeneration.from_pretrained(pretrained_model_name).to(device)
        model_with_adapters = MusicGenWithAdapters(musicgen_model, processor,
                                                   adapter_bottleneck_dim=adapter_bottleneck_dim,
                                                   device=device).to(device)
        # Load the state dicts for both the MusicGen model and the adapter
        checkpoint = torch.load(model_save_path, map_location=device)
        model_with_adapters.musicgen.load_state_dict(checkpoint['musicgen_state_dict'])
        model_with_adapters.adapter.load_state_dict(checkpoint['adapter_state_dict'])
        model_with_adapters.eval()
        total_params = sum(p.numel() for p in model_with_adapters.parameters())
        print(f"Total number of parameters in the fine-tuned model: {total_params}")
        return model_with_adapters
    else:
        # Load the pre-trained MusicGen model
        musicgen_model = MusicgenForConditionalGeneration.from_pretrained(pretrained_model_name).to(device)
        musicgen_model.eval()
        total_params = sum(p.numel() for p in musicgen_model.parameters())
        print(f"Total number of parameters in the original model: {total_params}")
        return musicgen_model


# Generate audio from a text prompt
def generate_audio(model, text_prompt, sample_rate=32000):
    processor = AutoProcessor.from_pretrained(pretrained_model_name)
    # Prepare input tensors for the text prompt
    input_data = processor(text=[text_prompt], return_tensors="pt").to(device)
    # Select the underlying MusicGen model (fine-tuned wrapper or pre-trained)
    if isinstance(model, MusicGenWithAdapters):
        musicgen = model.musicgen
    else:
        musicgen = model
    with torch.no_grad():
        generated_output = musicgen.generate(**input_data, max_new_tokens=max_new_tokens)
    waveform = generated_output.squeeze(0).cpu()
    # MusicGen outputs 32 kHz audio; resample only if a different rate is requested
    if sample_rate != 32000:
        resampler = torchaudio.transforms.Resample(orig_freq=32000, new_freq=sample_rate)
        waveform = resampler(waveform)
    return waveform
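
# Note: load_model() expects the fine-tuned checkpoint to contain the keys
# 'musicgen_state_dict' and 'adapter_state_dict'. A minimal sketch of how such a
# checkpoint would presumably be written by the training script (hypothetical,
# shown here only to document the expected format):
#
# torch.save(
#     {
#         'musicgen_state_dict': model_with_adapters.musicgen.state_dict(),
#         'adapter_state_dict': model_with_adapters.adapter.state_dict(),
#     },
#     model_save_path,
# )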

# Main inference code
if __name__ == "__main__":
    # Get a text prompt from the user
    text_prompt = input("Enter a text prompt for music generation: ")

    # Load the appropriate model (pre-trained or fine-tuned) based on the setting
    model = load_model(use_finetuned_model, model_save_path, device)

    # Generate audio
    waveform = generate_audio(model, text_prompt, sample_rate=sample_rate)

    # Save the generated audio (create the output directory if it does not exist)
    os.makedirs(os.path.dirname(output_audio_path), exist_ok=True)
    torchaudio.save(output_audio_path, waveform, sample_rate)
    print(f"Generated audio saved at {output_audio_path}")

    # Optional: visualize the waveform
    os.makedirs(os.path.dirname(AudioWaveform_graph_path), exist_ok=True)
    plt.figure(figsize=(12, 4))
    plt.plot(waveform.t().numpy())
    plt.savefig(AudioWaveform_graph_path)
    plt.show()
    print(f"Waveform graph saved at {AudioWaveform_graph_path}")
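
# Note (illustrative, not part of the original script): generate_audio() always
# decodes with the underlying MusicGen model and does not apply the adapter, even
# when the fine-tuned model is loaded. If the adapter were to be applied to the
# generated waveform, mirroring MusicGenWithAdapters.forward(), a minimal sketch
# (assuming the output is 32 kHz audio with exactly 32000 samples, matching the
# adapter's seq_len) would be:
#
# if use_finetuned_model and isinstance(model, MusicGenWithAdapters):
#     with torch.no_grad():
#         adapted = model.adapter(waveform.unsqueeze(0).to(device))  # (1, 2, 32000)
#     torchaudio.save(output_audio_path, adapted.squeeze(0).cpu(), sample_rate)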