# Adapters / ISMIR / Adapters_40M / GenerateAudios.py
# Hugging Face Hub upload metadata (author: 0hawkeye33, commit cc2cb10 verified,
# "Create GenerateAudios.py") — converted to comments so the file parses as Python.
import torch
import torch.nn as nn
import torchaudio
from transformers import AutoProcessor, MusicgenForConditionalGeneration
import os
import matplotlib.pyplot as plt
import pandas as pd
# ---------------------------------------------------------------------------
# Paths and run configuration (module-level constants used throughout).
# ---------------------------------------------------------------------------
pretrained_model_name = "facebook/musicgen-medium" # Pre-trained MusicGen model name on the Hugging Face Hub
model_save_path = r"/home/shivam.chauhan/.cache/huggingface/hub/models--0hawkeye33--Adapters/blobs/95d3c77dc73bd989622740c3ed49c13af31fccae5a1b507ca75fda0fa5aba091" # Path to the fine-tuned checkpoint; NOTE(review): machine-specific absolute path — update per host
sample_rate = 32000 # Desired sample rate for the saved output audio (MusicGen's native rate)
adapter_bottleneck_dim = 512 # Must match the dimension used at training time; use 612 for the linear adapter variant
max_new_tokens = 1024 # Controls generated length; per the author, 512 tokens ~ 10 seconds
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # prefer GPU when available
# Configuration to choose between pre-trained and fine-tuned model
use_finetuned_model = True # Set to True to use the fine-tuned model, False to use pre-trained model
"""
#################################### Linear Adapter class Begins here ############################
class Adapter(nn.Module):
def __init__(self, bottleneck_channels=256, input_channels=1, seq_len=32000):
super(Adapter, self).__init__()
self.adapter_down = nn.Linear(seq_len, bottleneck_channels)
self.activation = nn.ReLU()
self.adapter_up = nn.Linear(bottleneck_channels, seq_len)
self.dropout = nn.Dropout(p=0.0)
def forward(self, residual):
x = self.adapter_down(residual.squeeze(1))
x = self.activation(x)
x = self.adapter_up(x)
x = self.dropout(x + residual.squeeze(1))
return x.unsqueeze(1)
# MusicGen Model with Adapter (same as in training)
class MusicGenWithAdapters(nn.Module):
def __init__(self, musicgen_model, processor, adapter_bottleneck_dim=256, device='cpu'):
super(MusicGenWithAdapters, self).__init__()
self.musicgen = musicgen_model
self.adapter = Adapter(bottleneck_channels=adapter_bottleneck_dim, input_channels=2, seq_len=32000).to(device)
def forward(self, audio_text):
encoder_output = self.musicgen.generate(**audio_text, max_new_tokens=max_new_tokens)
encoder_output = encoder_output.to('cpu')
encoder_output = torchaudio.transforms.Resample(orig_freq=encoder_output.size(2), new_freq=32000)(encoder_output)
encoder_output = encoder_output.to(self.adapter.adapter_down.weight.device)
adapted = self.adapter(encoder_output)
return adapted
#################################### Linear Adapter class Ends here ############################
"""
#################################### CNN Adapter class Begins here ############################
class Adapter(nn.Module):
    """Convolutional bottleneck adapter for raw stereo waveforms.

    Down-projects the 2-channel audio into ``bottleneck_channels``, runs a
    stack of dilated residual blocks plus squeeze-excite channel attention,
    projects back to the input channel count, and wraps everything in a
    residual (skip) connection followed by dropout.

    Note: ``seq_len`` is accepted for interface parity with the commented-out
    linear adapter variant but is not used by the convolutional layers.
    """

    # Dilation schedule of the residual stack: exponentially growing receptive
    # field (1..256) followed by two extra low-dilation blocks for depth.
    _DILATIONS = (1, 2, 4, 8, 16, 32, 64, 128, 256, 1, 2)

    def __init__(self, bottleneck_channels=192, input_channels=2, seq_len=32000, dropout_prob=0.1):
        super(Adapter, self).__init__()
        # Down-projection: kernel_size=7 with padding=3 preserves seq_len.
        self.adapter_down = nn.Sequential(
            nn.Conv1d(input_channels, bottleneck_channels, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU(),
        )
        # Bottleneck: dilated residual blocks, SE attention, then a final
        # conv/norm/activation. Built as a single nn.Sequential in the same
        # order as before so module indices (and therefore checkpoint
        # state-dict keys) remain stable.
        stages = [
            ResidualBlock(bottleneck_channels, bottleneck_channels, kernel_size=7, dilation=d)
            for d in self._DILATIONS
        ]
        stages += [
            SEBlock(bottleneck_channels),
            nn.Conv1d(bottleneck_channels, bottleneck_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm1d(bottleneck_channels),
            nn.GELU(),
        ]
        self.bottleneck = nn.Sequential(*stages)
        # Up-projection back to the original channel count (seq_len preserved).
        self.adapter_up = nn.Sequential(
            nn.Conv1d(bottleneck_channels, input_channels, kernel_size=7, stride=1, padding=3),
            nn.BatchNorm1d(input_channels),
        )
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, residual):
        """Apply the adapter; ``residual`` is (batch, input_channels, seq_len)."""
        out = self.adapter_down(residual)
        out = self.bottleneck(out)
        out = self.adapter_up(out)
        # Skip connection around the whole adapter, then dropout.
        return self.dropout(out + residual)
class ResidualBlock(nn.Module):
    """Two dilated Conv1d + BatchNorm layers with an identity skip connection.

    Padding is (kernel_size // 2) * dilation, so the temporal length is
    preserved; the identity skip therefore requires
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=7, stride=1, dilation=1):
        super(ResidualBlock, self).__init__()
        same_pad = (kernel_size // 2) * dilation  # "same" padding for dilated convs
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
                               padding=same_pad, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.activation = nn.GELU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, stride,
                               padding=same_pad, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # NOTE(review): layer_norm is never applied in forward(); it is kept
        # only so checkpoints saved with its parameters still load strictly.
        self.layer_norm = nn.LayerNorm(out_channels)

    def forward(self, x):
        shortcut = x  # preserve the input for the skip connection
        out = self.bn1(self.conv1(x))
        out = self.activation(out)
        out = self.bn2(self.conv2(out))
        return out + shortcut  # skip connection (no activation after the add)
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention for 1-D feature maps.

    Squeezes each channel to a scalar via global average pooling, passes the
    channel vector through a small bottleneck MLP ending in a sigmoid, and
    rescales the input channels by the resulting attention weights.
    """

    def __init__(self, channels, reduction=8):
        super(SEBlock, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool1d(1)  # squeeze: (B, C, T) -> (B, C, 1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c = x.shape[0], x.shape[1]
        squeezed = self.global_avg_pool(x).view(b, c)
        weights = self.fc(squeezed).view(b, c, 1)
        return x * weights  # excite: per-channel rescaling of the input
# MusicGen Model with Adapter (same as in training)
class MusicGenWithAdapters(nn.Module):
    """Wraps a pretrained MusicGen model and post-processes its generated
    waveform with the convolutional Adapter.

    NOTE(review): forward() calls ``generate``, so this wrapper produces
    finished audio rather than differentiable logits.
    """

    def __init__(self, musicgen_model, processor, adapter_bottleneck_dim=256, device='cpu'):
        super(MusicGenWithAdapters, self).__init__()
        # `processor` is accepted for interface parity but is not stored/used.
        self.musicgen = musicgen_model
        self.adapter = Adapter(bottleneck_channels=adapter_bottleneck_dim, input_channels=2, seq_len=32000).to(device)

    def forward(self, audio_text):
        # Generate audio from the processed text inputs; `max_new_tokens` is
        # the module-level length setting.
        encoder_output = self.musicgen.generate(**audio_text, max_new_tokens=max_new_tokens)
        encoder_output = encoder_output.to('cpu')
        # NOTE(review): `orig_freq` is set to the *number of samples*
        # (encoder_output.size(2)), not a sample rate — this looks like a bug,
        # but it mirrors the training-time code, so changing it would break
        # parity with the saved adapter checkpoint. Confirm intent.
        encoder_output = torchaudio.transforms.Resample(orig_freq=encoder_output.size(2), new_freq=32000)(encoder_output)
        # NOTE(review): in the CNN adapter, `adapter_down` is an nn.Sequential
        # with no `.weight` attribute, so this line would raise AttributeError
        # if forward() is ever called — likely left over from the linear
        # adapter variant where adapter_down was an nn.Linear.
        encoder_output = encoder_output.to(self.adapter.adapter_down.weight.device)
        # NOTE(review): the adapter expects 2 input channels, while the
        # commented-out transformer variant explicitly expands to 2 channels
        # here (`.expand(-1, 2, -1)`); verify the generated channel count.
        adapted = self.adapter(encoder_output)
        return adapted
#################################### CNN Adapter class Ends here ############################
"""
#################################### Transformer Adapter class Begins here ############################
class TransformerAdapter(nn.Module):
def __init__(self, input_dim=32000, bottleneck_dim=1024, num_heads=4, ff_dim=512, dropout_prob=0.1):
super(TransformerAdapter, self).__init__()
# Project input to a lower dimension.
self.down_proj = nn.Linear(input_dim, bottleneck_dim)
# Multi-head self-attention.
self.attn = nn.MultiheadAttention(embed_dim=bottleneck_dim, num_heads=num_heads, dropout=dropout_prob)
# Layer normalization before attention.
self.ln1 = nn.LayerNorm(bottleneck_dim)
# Feed-forward network.
self.ffn = nn.Sequential(
nn.Linear(bottleneck_dim, ff_dim),
nn.GELU(),
nn.Linear(ff_dim, bottleneck_dim),
nn.Dropout(dropout_prob)
)
self.ln2 = nn.LayerNorm(bottleneck_dim)
# Project back to the original dimension.
self.up_proj = nn.Linear(bottleneck_dim, input_dim)
self.dropout = nn.Dropout(dropout_prob)
def forward(self, x):
# x: (batch_size, channels, seq_len)
batch_size, channels, seq_len = x.size()
# Flatten channels into the batch dimension.
x = x.reshape(batch_size * channels, seq_len) # Use reshape instead of view.
x = self.down_proj(x) # (B * channels, bottleneck_dim)
# Add a sequence dimension for attention.
x = x.unsqueeze(0) # (1, B * channels, bottleneck_dim)
attn_out, _ = self.attn(x, x, x)
x = x + attn_out # Residual connection.
x = self.ln1(x)
# Remove the sequence dimension.
x = x.squeeze(0) # (B * channels, bottleneck_dim)
ffn_out = self.ffn(x)
x = x + ffn_out # Residual connection.
x = self.ln2(x)
x = self.up_proj(x)
x = self.dropout(x)
# Restore the original shape.
return x.reshape(batch_size, channels, seq_len) # Use reshape here as well.
# MusicGen Model with Adapter for Transformer
class MusicGenWithAdapters(nn.Module):
def __init__(self, musicgen_model, processor, adapter_bottleneck_dim, device='cpu'):
super(MusicGenWithAdapters, self).__init__()
self.musicgen = musicgen_model
# Initialize the transformer-based adapter.
self.adapter = TransformerAdapter(
input_dim=32000,
bottleneck_dim=adapter_bottleneck_dim,
num_heads=4,
ff_dim=512,
dropout_prob=0.0
).to(device)
def forward(self, audio_text):
encoder_output = self.musicgen.generate(**audio_text, max_new_tokens=128)
# Move encoder output to CPU for resampling.
encoder_output = encoder_output.to('cpu')
encoder_output = torchaudio.transforms.Resample(
orig_freq=encoder_output.size(2), new_freq=32000
)(encoder_output)
# Move back to device (using one of the adapter parameters for reference).
encoder_output = encoder_output.to(next(self.adapter.down_proj.parameters()).device)
# Expand from 1 channel to 2 channels if needed.
encoder_output = encoder_output.expand(-1, 2, -1)
# Pass through the transformer-based adapter.
adapted = self.adapter(encoder_output)
return adapted
#################################### Transformer Adapter class Ends here ############################
"""
# Function to load the model based on the configuration
def load_model(use_finetuned_model, model_save_path, device):
    """Load either the fine-tuned (MusicGen + adapter) or the plain
    pre-trained MusicGen model, in eval mode.

    Parameters
    ----------
    use_finetuned_model : bool
        When True, wrap MusicGen with the adapter and restore the checkpoint
        from ``model_save_path``; otherwise return the pre-trained model.
    model_save_path : str
        Path to a torch checkpoint containing 'musicgen_state_dict' and
        'adapter_state_dict' (only read when ``use_finetuned_model`` is True).
    device : torch.device
        Device to place the model on.

    Returns
    -------
    tuple
        ``(model, processor)`` — the eval-mode model and its text processor.
    """
    # Both branches need the processor and the base MusicGen model, so load
    # them once instead of duplicating the calls per branch.
    processor = AutoProcessor.from_pretrained(pretrained_model_name)
    musicgen_model = MusicgenForConditionalGeneration.from_pretrained(pretrained_model_name).to(device)

    if use_finetuned_model:
        model = MusicGenWithAdapters(
            musicgen_model, processor,
            adapter_bottleneck_dim=adapter_bottleneck_dim, device=device,
        ).to(device)
        # Restore both the (possibly fine-tuned) MusicGen weights and the
        # adapter weights from the combined checkpoint.
        checkpoint = torch.load(model_save_path, map_location=device)
        model.musicgen.load_state_dict(checkpoint['musicgen_state_dict'])
        model.adapter.load_state_dict(checkpoint['adapter_state_dict'])
        label = "fine-tuned"
    else:
        model = musicgen_model
        label = "Original"

    model.eval()
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total number of parameters in the {label} model: {total_params}")
    return model, processor
# Function to generate audio from a text prompt
def generate_audio(model, processor, text_prompt, sample_rate=32000):
    """Generate a waveform for ``text_prompt`` and return it as a CPU tensor.

    The underlying MusicGen model's ``generate`` is always used (the adapter
    wrapper, if present, is bypassed), and the audio is resampled from 32 kHz
    only when a different ``sample_rate`` is requested.
    """
    # Tokenize the prompt; `device` is the module-level target device.
    inputs = processor(text=[text_prompt], return_tensors="pt").to(device)
    # Unwrap the adapter container so we always call MusicGen's generate().
    musicgen = model.musicgen if isinstance(model, MusicGenWithAdapters) else model
    with torch.no_grad():
        audio = musicgen.generate(**inputs, max_new_tokens=max_new_tokens)
    waveform = audio.squeeze(0).cpu()
    if sample_rate == 32000:
        return waveform
    # Convert from MusicGen's native 32 kHz to the requested rate.
    return torchaudio.transforms.Resample(orig_freq=32000, new_freq=sample_rate)(waveform)
# Main inference code
if __name__ == "__main__":
    # Output locations for generated audio files and waveform plots.
    audio_output_dir = "./GeneratedAudios/Makam/CNN/"
    graph_output_dir = "../Random/"

    # Load the appropriate model (pre-trained or fine-tuned) based on the setting
    model, processor = load_model(use_finetuned_model, model_save_path, device)

    # Read the prompt metadata (one record per target audio file).
    metadata_df = pd.read_json(r'./GeneratedAudios/Makam/Prompts.json')

    # Create the output directories once, before the loop (the original code
    # re-created them on every iteration).
    os.makedirs(audio_output_dir, exist_ok=True)
    os.makedirs(graph_output_dir, exist_ok=True)

    for index, row in metadata_df.iterrows():
        text_prompt = row['captions']
        print(f"Generating audio for prompt: {text_prompt}")

        # Generate audio for this prompt at the configured sample rate.
        waveform = generate_audio(model, processor, text_prompt, sample_rate=sample_rate)

        # Name the output after the source file referenced in the metadata.
        output_audio_filename = os.path.basename(row['location'])
        output_audio_path = os.path.join(audio_output_dir, output_audio_filename)
        torchaudio.save(output_audio_path, waveform, sample_rate)
        print(f"Generated audio saved at {output_audio_path}")

        # Save a waveform plot alongside the audio (transpose so each channel
        # becomes one line over time).
        AudioWaveform_graph_filename = os.path.splitext(output_audio_filename)[0] + '.jpeg'
        AudioWaveform_graph_path = os.path.join(graph_output_dir, AudioWaveform_graph_filename)
        plt.figure(figsize=(12, 4))
        plt.plot(waveform.t().numpy())
        plt.savefig(AudioWaveform_graph_path)
        plt.close()  # Close the figure to free memory
        print(f"Waveform graph saved at {AudioWaveform_graph_path}")