Spaces:
Sleeping
Sleeping
File size: 15,463 Bytes
827b824 41f5dcd a1d5625 41f5dcd 827b824 a6e329e 827b824 5a3fac8 827b824 66d9b0c 827b824 66d9b0c 827b824 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 |
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import csv
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# --- Configuration & Constants ---
# Directory that holds every model artifact and the genre table referenced below.
MODEL_BASE_PATH = "MQGAN/models"
# Checkpoint handed to get_pre_encoder() during model loading.
PREENC_PATH = os.path.join(MODEL_BASE_PATH, "preenc_musicv3.pth")
# Checkpoint read with torch.load(); its "model_state_dict" entry is loaded into MusicLSTM.
MUSICLSTM_PATH = os.path.join(MODEL_BASE_PATH, "musiclstm_1.pt")
# Extension-less prefix: the existence check looks for "<prefix>.ts" or "<prefix>.pth",
# and the prefix itself is what gets passed to vocoder.load_ts().
ISTFTNET_PATH = os.path.join(MODEL_BASE_PATH, "istftnet_music_ft")
# CSV read by load_genres(); parsed as a header row followed by "genre_id,genre_name" rows.
GENRE_CSV_PATH = os.path.join(MODEL_BASE_PATH, "genre_ids.csv")
# Passed to the MusicLSTM constructor as vocab_size.
VOCAB_SIZE = 1010
NUM_GENRES = 250 # Fixed value passed to MusicLSTM(num_genres=...); it is NOT derived from genre_ids.csv
# Token id used to seed generation (passed as initial_bos to generate_music_tokens).
BOS_ID = 1001
# Passed to the MusicLSTM constructor as pad_id.
PAD_ID = 1002
# Multiplied by the requested duration (seconds) to get the number of tokens to generate.
TOKENS_PER_SECOND = 86
# Sample rate reported to gr.Audio alongside the waveform.
# NOTE(review): presumably matches the vocoder's output rate - confirm against the vocoder config.
SAMPLING_RATE = 44100
# Use Agg backend for Matplotlib in headless environments (like Gradio).
# Note this runs after matplotlib.pyplot has already been imported at the top of the file.
matplotlib.use('Agg')
# --- Import the MQGAN modules ---
# These imports are unconditional: there is no fallback, so a missing MQGAN
# package raises ImportError here and the script stops.
from MQGAN import get_pre_encoder as actual_get_pre_encoder
from MQGAN import ISTFTNetFE as actual_ISTFTNetFE
from MQGAN import MusicLSTM as actual_MusicLSTM
# Expose the imported implementations under the plain names used by the rest of
# this file (no placeholder definitions exist in this file to be replaced).
get_pre_encoder = actual_get_pre_encoder
ISTFTNetFE = actual_ISTFTNetFE
MusicLSTM = actual_MusicLSTM
# --- Global Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Warn about missing artifacts up front so a load failure below is easy to diagnose.
if not os.path.exists(PREENC_PATH):
    print(f"Warning: Pre-encoder model not found at {PREENC_PATH}")
if not os.path.exists(MUSICLSTM_PATH):
    print(f"Warning: MusicLSTM model not found at {MUSICLSTM_PATH}")
# ISTFTNET_PATH is a prefix; either a ".ts" or a ".pth" file satisfies the check.
if not (os.path.exists(ISTFTNET_PATH + ".ts") or os.path.exists(ISTFTNET_PATH + ".pth")):
    print(f"Warning: Vocoder model not found at {ISTFTNET_PATH}. (.ts or .pth)")

# Start from the "nothing loaded" state so that MODELS_LOADED only becomes True
# when every model really loaded. The UI callback and the __main__ block both
# branch on this flag; setting it unconditionally made those branches dead code.
pre_enc = None
model = None
vocoder = None
MODELS_LOADED = False
try:
    pre_enc = get_pre_encoder(PREENC_PATH, device, channels=[192, 768, 1024, 1024], kernel_sizes=[3, 5, 7, 11],
                              mel_channels=160)
    # NUM_GENRES is fixed rather than derived from genre_ids.csv; the checkpoint
    # must have been trained with the same value.
    model = MusicLSTM(vocab_size=VOCAB_SIZE, num_genres=NUM_GENRES, pad_id=PAD_ID)
    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects.
    # Only load checkpoints from a trusted source.
    chkp = torch.load(MUSICLSTM_PATH, map_location=device, weights_only=False)
    model.load_state_dict(chkp["model_state_dict"])
    model = model.to(device).eval()
    vocoder = ISTFTNetFE(None, None)  # Adjust constructor if needed
    vocoder.load_ts(ISTFTNET_PATH, device)  # load_ts might expect a directory/prefix
    vocoder = vocoder.to(device).eval()
except Exception as e:  # top-level boundary: report the failure and keep the app importable
    import traceback

    print(f"Error loading models: {e}")
    traceback.print_exc()
    # Drop any partially loaded model so callers never see a half-initialised set.
    pre_enc = None
    model = None
    vocoder = None
else:
    MODELS_LOADED = True
    print("Models loaded successfully.")
# --- Genre Loading ---
def load_genres(csv_filepath):
    """Load the genre-name -> genre-id table from a CSV file.

    The file must start with a header row, followed by ``genre_id,genre_name``
    rows. Rows that do not have exactly two columns (blank lines included) are
    skipped. Surrounding whitespace is stripped from genre names. When a name
    occurs more than once, the last id wins and the name is listed only once.

    Args:
        csv_filepath: Path of the CSV file to read.

    Returns:
        A ``(genre_map, genre_names)`` tuple. ``genre_map`` maps genre name to
        integer id; ``genre_names`` lists the names in first-seen order without
        duplicates. A built-in default table is returned when the file is
        missing, empty, unreadable, or contains a non-integer id. A file that
        holds only a header row yields ``({}, [])``.
    """
    def _defaults():
        # Build a fresh dict on every call so callers can never share state.
        fallback = {"Rock": 176, "Pop": 100, "Electronic": 50}  # Example
        return fallback, list(fallback)

    if not os.path.exists(csv_filepath):
        print(f"Warning: Genre CSV file not found at {csv_filepath}. Using default genres.")
        return _defaults()

    genre_map = {}
    try:
        # newline='' is what the csv module expects for files it parses.
        with open(csv_filepath, mode='r', encoding='utf-8', newline='') as infile:
            reader = csv.reader(infile)
            if next(reader, None) is None:  # no header row at all
                raise ValueError("file is empty")
            for row in reader:
                if len(row) != 2:
                    continue
                # int() raises ValueError on a malformed id, which rejects the whole file.
                genre_map[row[1].strip()] = int(row[0])
    except (OSError, ValueError, csv.Error) as e:
        # UnicodeDecodeError is a ValueError subclass, so bad encodings land here too.
        print(f"Error loading genres from {csv_filepath}: {e}. Using default genres.")
        return _defaults()

    # dict preserves insertion order, so this is the de-duplicated file order.
    genre_names = list(genre_map)
    print(f"Loaded {len(genre_names)} genres from {csv_filepath}")
    return genre_map, genre_names
# Resolve the genre table that backs the dropdown and the name -> id lookup.
genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)
if len(genre_name_list) == 0:
    # The loader came back with no genres at all: fall back to a single
    # known entry so the dropdown always has something to offer.
    genre_to_id_map, genre_name_list = {"Rock": 176}, ["Rock"]
# --- Helper Functions (from user code) ---
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9, filter_value: float = -float('Inf')) -> torch.Tensor:
    """Mask the low-probability tail of ``logits`` (nucleus / top-p filtering).

    Along the last dimension, tokens are ranked by logit and the smallest
    prefix whose cumulative softmax probability exceeds ``top_p`` is kept.
    Every other position is set to ``filter_value``. The top-ranked token is
    always kept. Each row of a batched input is filtered independently.

    Args:
        logits: Unnormalised scores; the vocabulary is the last dimension.
            Any number of leading batch dimensions (including none) is accepted.
        top_p: Cumulative-probability threshold. ``None`` or ``1.0`` disables
            filtering.
        filter_value: Value written into the removed positions.

    Returns:
        A tensor with the same shape as ``logits``. When filtering is disabled
        the input tensor itself is returned; otherwise a new tensor is returned
        and ``logits`` is left unmodified.
    """
    if top_p is None or top_p == 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Mark everything past the threshold, then shift the mask one step right so
    # the token that crosses the threshold is itself kept.
    sorted_remove = cumulative_probs > top_p
    sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
    sorted_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order. scatter works
    # per row, so 1-D and batched inputs share one code path and the output
    # keeps the input's shape.
    remove_mask = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove_mask, filter_value)
@torch.no_grad()
def generate_music_tokens(
    music_model,
    initial_bos: int,
    genre_id: int,
    n_tokens_to_generate: int,
    temperature: float = 0.75,
    top_p: float = 0.9,
    current_device: torch.device = None,
) -> torch.LongTensor:
    """Sample a token sequence from ``music_model`` one step at a time.

    The model is put in eval mode. Generation starts from ``initial_bos`` with
    a zero LSTM state. Each step adds the genre embedding to the previous
    token's embedding, advances the LSTM, divides the projected logits by the
    temperature (floored at 1e-8), applies top-p filtering, and samples one
    token.

    Args:
        music_model: Object exposing ``lstm``, ``genre_emb``, ``tok_emb``,
            ``dropout`` and ``proj`` sub-modules.
        initial_bos: Token id that seeds the sequence.
        genre_id: Index into ``music_model.genre_emb``.
        n_tokens_to_generate: Number of tokens to sample after the seed.
        temperature: Divisor applied to the logits.
        top_p: Threshold forwarded to ``top_p_filtering``.
        current_device: Device for the working tensors; defaults to the device
            of the model's first parameter.

    Returns:
        A ``(1, n_tokens_to_generate + 1)`` long tensor, seed token first.

    Raises:
        ValueError: If ``genre_id`` is outside the genre embedding table.
    """
    target = current_device if current_device is not None else next(music_model.parameters()).device
    music_model.eval()

    # Zero-initialised (hidden, cell) state for a batch of one.
    state_shape = (music_model.lstm.num_layers, 1, music_model.lstm.hidden_size)
    state = (
        torch.zeros(state_shape, device=target),
        torch.zeros(state_shape, device=target),
    )

    genre_index = torch.tensor([genre_id], device=target, dtype=torch.long)
    # Validate before the embedding lookup so the caller gets a readable error.
    if genre_id < 0 or genre_id >= music_model.genre_emb.num_embeddings:
        raise ValueError(f"genre_id {genre_id} is out of bounds for genre_emb layer "
                         f"with {music_model.genre_emb.num_embeddings} embeddings.")
    genre_vec = music_model.genre_emb(genre_index)

    tokens = [initial_bos]
    for _step in range(n_tokens_to_generate):
        previous = torch.tensor([tokens[-1]], device=target)
        # Add a length-1 axis so the LSTM receives a single-step, single-item input.
        step_input = (music_model.tok_emb(previous) + genre_vec).unsqueeze(1)
        step_output, state = music_model.lstm(step_input, state)
        features = music_model.dropout(step_output.squeeze(1))
        scaled_logits = music_model.proj(features) / max(temperature, 1e-8)
        probabilities = F.softmax(top_p_filtering(scaled_logits, top_p=top_p), dim=-1)
        tokens.append(torch.multinomial(probabilities, num_samples=1).item())

    return torch.tensor([tokens], device=target, dtype=torch.long)
# --- Main Gradio Inference Function ---
def generate_audio_and_plot(genre_name_selected, duration_seconds, temperature, top_p_val):
    """Gradio callback: generate music for a genre and plot its mel spectrogram.

    Samples ``TOKENS_PER_SECOND * duration_seconds`` tokens, decodes them with
    the pre-encoder, runs the vocoder on the result, and draws the mel array.

    Figures are created with ``matplotlib.figure.Figure`` rather than pyplot.
    pyplot keeps every figure in a global registry until it is closed, which
    leaks memory in a long-running server, and its "current figure" state is
    shared between concurrent requests.

    Args:
        genre_name_selected: Genre name; looked up in ``genre_to_id_map``.
        duration_seconds: Requested length of the clip in seconds.
        temperature: Sampling temperature forwarded to the token generator.
        top_p_val: Top-p threshold forwarded to the token generator.

    Returns:
        ``(figure, (SAMPLING_RATE, waveform), status_message)``. On failure the
        figure shows the error text and the waveform is one second of silence.
        When the genre name is unknown, ``(None, None, message)`` is returned.
    """
    from matplotlib.figure import Figure

    def _message_figure(text, wrap=False):
        # Blank figure carrying a centred message, used on the failure paths.
        fig = Figure()
        ax = fig.subplots()
        ax.text(0.5, 0.5, text, ha='center', va='center', wrap=wrap)
        return fig

    # Compare against None explicitly: the truth value of a model object is not
    # a reliable "was it loaded" signal.
    if not MODELS_LOADED or any(m is None for m in (pre_enc, model, vocoder)):
        dummy_audio = np.zeros(SAMPLING_RATE)  # 1 second of silence
        return (
            _message_figure("Models not loaded. Cannot generate."),
            (SAMPLING_RATE, dummy_audio),
            "Error: Models not loaded. Check console.",
        )

    try:
        genre_id_selected = genre_to_id_map.get(genre_name_selected)
        if genre_id_selected is None:
            # Fallback if genre name not in map (should not happen with dropdown)
            return None, None, f"Error: Genre '{genre_name_selected}' not found."

        n_tokens = int(TOKENS_PER_SECOND * duration_seconds)
        print(
            f"Generating for Genre: {genre_name_selected} (ID: {genre_id_selected}), Duration: {duration_seconds}s, Temp: {temperature}, Top-P: {top_p_val}")

        # Inference only: without no_grad the decoder and vocoder would build an
        # autograd graph, and .numpy() below fails on a tensor that requires grad.
        with torch.no_grad():
            indices = generate_music_tokens(
                model, BOS_ID, genre_id_selected, n_tokens, temperature, top_p_val, device
            )  # (1, T_tokens)
            # NOTE(review): assumes decode() returns (1, T_mel_frames, Mel_channels)
            # and the vocoder wants (1, Mel_channels, T_mel_frames) - confirm.
            mel = pre_enc.decode(indices)
            audio_waveform_tensor = vocoder.infer(mel.transpose(1, 2))

        if isinstance(audio_waveform_tensor, torch.Tensor):
            audio_waveform_np = audio_waveform_tensor.detach().squeeze().cpu().numpy()
        elif isinstance(audio_waveform_tensor, np.ndarray):
            audio_waveform_np = audio_waveform_tensor.squeeze()
        else:
            raise TypeError(f"Vocoder output type not recognized: {type(audio_waveform_tensor)}")

        # Hand a float32 array to the audio component.
        if audio_waveform_np.dtype != np.float32:
            audio_waveform_np = audio_waveform_np.astype(np.float32)

        # Mel spectrogram plot; transposed so the first array axis runs along x.
        mel_cpu = mel.squeeze(0).detach().cpu().numpy()
        fig = Figure(figsize=(10, 4))
        ax = fig.subplots()
        image = ax.imshow(mel_cpu.T, origin='lower', aspect='auto', interpolation='nearest')
        ax.set_xlabel('Time Frame')
        ax.set_ylabel('Mel Channel')
        ax.set_title(f'Mel Spectrogram: {genre_name_selected}, {duration_seconds}s')
        fig.colorbar(image, ax=ax, label='Amplitude')
        fig.tight_layout()

        status_message = f"Generated {duration_seconds}s of {genre_name_selected} music."
        print(status_message)
        return fig, (SAMPLING_RATE, audio_waveform_np), status_message
    except Exception as e:  # UI boundary: report the failure instead of crashing the request
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Return an error plot and silent audio on failure
        empty_audio = np.zeros(SAMPLING_RATE)
        return _message_figure(f"Error: {e}", wrap=True), (SAMPLING_RATE, empty_audio), f"Error: {e}"
# --- Gradio UI Definition ---
# The interface wires four inputs (genre, duration, temperature, top-p) to
# generate_audio_and_plot and shows its three return values.
iface = gr.Interface(
    fn=generate_audio_and_plot,
    inputs=[
        gr.Dropdown(
            choices=genre_name_list,
            # First genre when the list is non-empty, otherwise no preselection.
            value=next(iter(genre_name_list), None),
            label="Genre",
        ),
        gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Duration (seconds)"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.90, label="Temperature"),
        gr.Slider(minimum=0.5, maximum=1.0, step=0.01, value=0.95, label="Top-P Sampling"),
    ],
    outputs=[
        gr.Plot(label="Mel Spectrogram"),
        # The callback returns a (sample_rate, data) pair for this component.
        gr.Audio(label="Generated Audio", type="numpy"),
        gr.Textbox(label="Status"),
    ],
    title="MusicLSTM + MQGAN Music Generator sample",
    description="Genre-conditioned music generation with MusicLSTM. Select a genre, duration, and sampling parameters. Note that outputs may be nonsense depending on genre and luck.",
    allow_flagging="never",
)
if __name__ == "__main__":
    if not MODELS_LOADED:
        print("\n--- WARNING: MODELS DID NOT LOAD ---")
        print("The Gradio app will launch, but generation will fail.")
        print("Please check model paths, MQGAN class definitions/imports, and any errors above.")

    # Seed a small genre table when none exists so the NEXT start-up has a CSV to read.
    if not os.path.exists(GENRE_CSV_PATH):
        print(f"Creating a dummy '{GENRE_CSV_PATH}' for demonstration.")
        # The models directory may be missing as well; create it before writing.
        os.makedirs(os.path.dirname(GENRE_CSV_PATH) or ".", exist_ok=True)
        # Write UTF-8 explicitly: load_genres reads the file back as UTF-8.
        with open(GENRE_CSV_PATH, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["genre_id", "genre_name"])
            writer.writerows([
                [0, "60s"],
                [176, "Rock"],  # From user's example
                [48, "Electronic"],  # From user's example
                [10, "Ambient"],
                [20, "Classical"],
            ])
        # The genre map is deliberately NOT reloaded here. The dropdown in `iface`
        # was already built from the table loaded at import time; swapping the
        # map now would leave the UI offering names the new map cannot resolve
        # and would silently change ids for the names it can. The new file takes
        # effect on the next run.

    iface.launch(share=True)
|