# MusicLSTMDemo / app.py
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import csv
import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# --- Configuration & Constants ---
MODEL_BASE_PATH = "MQGAN/models"
PREENC_PATH = os.path.join(MODEL_BASE_PATH, "preenc_musicv3.pth")
MUSICLSTM_PATH = os.path.join(MODEL_BASE_PATH, "musiclstm_1.pt")
ISTFTNET_PATH = os.path.join(MODEL_BASE_PATH, "istftnet_music_ft")
GENRE_CSV_PATH = os.path.join(MODEL_BASE_PATH, "genre_ids.csv")
VOCAB_SIZE = 1010
NUM_GENRES = 250  # Fixed genre-embedding size baked into the checkpoint; could instead be derived from genre_ids.csv
BOS_ID = 1001
PAD_ID = 1002
TOKENS_PER_SECOND = 86
SAMPLING_RATE = 44100
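# A requested duration maps linearly to a token count: for example, a 10 s
# clip is int(86 * 10) = 860 tokens (see generate_audio_and_plot below).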
# Use Agg backend for Matplotlib in headless environments (like Gradio)
matplotlib.use('Agg')
# --- Import MQGAN modules ---
# These must resolve from the MQGAN package on sys.path (appended above).
from MQGAN import get_pre_encoder, ISTFTNetFE, MusicLSTM
# --- Global Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Check if model paths exist
if not os.path.exists(PREENC_PATH):
print(f"Warning: Pre-encoder model not found at {PREENC_PATH}")
if not os.path.exists(MUSICLSTM_PATH):
print(f"Warning: MusicLSTM model not found at {MUSICLSTM_PATH}")
if not (os.path.exists(ISTFTNET_PATH + ".ts") or os.path.exists(ISTFTNET_PATH + ".pth")): # .ts or .pth for vocoder
print(f"Warning: Vocoder model not found at {ISTFTNET_PATH}. (.ts or .pth)")
pre_enc = get_pre_encoder(PREENC_PATH, device, channels=[192, 768, 1024, 1024], kernel_sizes=[3, 5, 7, 11],
mel_channels=160)
# NUM_GENRES could be derived from genre_ids.csv, but the checkpoint was
# trained with a fixed genre-embedding size, so we keep the constant above.
model = MusicLSTM(vocab_size=VOCAB_SIZE, num_genres=NUM_GENRES, pad_id=PAD_ID)
chkp = torch.load(MUSICLSTM_PATH, map_location=device, weights_only=False)
model.load_state_dict(chkp["model_state_dict"])
model = model.to(device).eval()
vocoder = ISTFTNetFE(None, None) # Adjust constructor if needed
vocoder.load_ts(ISTFTNET_PATH, device) # load_ts might expect a directory/prefix
vocoder = vocoder.to(device).eval()
MODELS_LOADED = True
print("Models loaded successfully (or placeholders initialized).")
# --- Genre Loading ---
def load_genres(csv_filepath):
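    """Load a genre-name -> genre-id mapping from a two-column CSV.

    Falls back to a small hard-coded default map if the file is missing or
    unreadable, so the UI can still launch.
    """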
genre_map = {}
genre_names = []
if not os.path.exists(csv_filepath):
print(f"Warning: Genre CSV file not found at {csv_filepath}. Using default genres.")
# Provide a few default genres if CSV is missing
default_genres = {"Rock": 176, "Pop": 100, "Electronic": 50} # Example
genre_map = default_genres
genre_names = list(default_genres.keys())
return genre_map, genre_names
try:
with open(csv_filepath, mode='r', encoding='utf-8') as infile:
reader = csv.reader(infile)
next(reader) # Skip header
for row in reader:
if len(row) == 2:
genre_id, genre_name = int(row[0]), row[1]
genre_map[genre_name] = genre_id
genre_names.append(genre_name)
        # NUM_GENRES is intentionally not updated from the CSV here: the
        # checkpoint's genre embedding has a fixed size, and changing the
        # count at load time would break state_dict loading.
print(f"Loaded {len(genre_names)} genres from {csv_filepath}")
except Exception as e:
print(f"Error loading genres from {csv_filepath}: {e}. Using default genres.")
default_genres = {"Rock": 176, "Pop": 100, "Electronic": 50}
genre_map = default_genres
genre_names = list(default_genres.keys())
return genre_map, genre_names
genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)
if not genre_name_list: # Fallback if loading fails completely
genre_name_list = ["Rock"] # Default if CSV is empty or unreadable
genre_to_id_map = {"Rock": 176}
# --- Helper Functions (from user code) ---
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9, filter_value: float = -float('Inf')) -> torch.Tensor:
    """Nucleus (top-p) filtering: keep only the smallest set of tokens whose
    cumulative probability exceeds ``top_p`` and mask everything else."""
    if top_p is None or top_p >= 1.0:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mark tokens past the nucleus, then shift the mask right so the token
    # that crosses the threshold is kept; at least one candidate survives.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Scatter the mask back to the original (unsorted) token positions; this
    # works for any batch size, unlike indexing with flattened indices.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, filter_value)
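# Illustrative example: with sorted probabilities [0.5, 0.3, 0.15, 0.05] and
# top_p = 0.8, the cumulative sums are [0.5, 0.8, 0.95, 1.0]; the raw mask
# first fires at 0.95, and after the right-shift that third token is still
# kept, so only the 0.05 tail is set to -inf and can never be sampled.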
@torch.no_grad()
def generate_music_tokens(
music_model, # Renamed from 'model' to avoid conflict with global 'model'
initial_bos: int,
genre_id: int,
n_tokens_to_generate: int,
temperature: float = 0.75,
top_p: float = 0.9,
current_device: torch.device = None, # Renamed from 'device'
) -> torch.LongTensor:
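    """Autoregressively sample ``n_tokens_to_generate`` tokens from the LSTM,
    adding the genre embedding to the token embedding at every step and
    carrying the (h, c) state across steps."""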
if current_device is None:
current_device = next(music_model.parameters()).device
music_model.eval()
num_layers = music_model.lstm.num_layers
hidden_size = music_model.lstm.hidden_size
h = torch.zeros(num_layers, 1, hidden_size, device=current_device)
c = torch.zeros(num_layers, 1, hidden_size, device=current_device)
genre_tensor = torch.tensor([genre_id], device=current_device, dtype=torch.long)
# Check if genre_id is within the valid range for the embedding layer
if not (0 <= genre_id < music_model.genre_emb.num_embeddings):
raise ValueError(f"genre_id {genre_id} is out of bounds for genre_emb layer "
f"with {music_model.genre_emb.num_embeddings} embeddings.")
genre_emb = music_model.genre_emb(genre_tensor)
token = initial_bos
generated = [token]
for _ in range(n_tokens_to_generate):
tok_e = music_model.tok_emb(torch.tensor([token], device=current_device))
inp = (tok_e + genre_emb).unsqueeze(1) # LSTM expects (seq_len, batch, features) -> (1,1,emb_dim)
out, (h, c) = music_model.lstm(inp, (h, c))
out = music_model.dropout(out.squeeze(1)) # (1, hidden_size)
logits = music_model.proj(out) / max(temperature, 1e-8) # (1, vocab_size)
filtered_logits = top_p_filtering(logits, top_p=top_p)
probs = F.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).item()
generated.append(next_token)
token = next_token
return torch.tensor([generated], device=current_device, dtype=torch.long)
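# Illustrative call (uses the globals loaded above; shapes follow the code):
#   idx = generate_music_tokens(model, BOS_ID, genre_to_id_map["Rock"],
#                               n_tokens_to_generate=860, temperature=0.9,
#                               top_p=0.95, current_device=device)
#   # idx has shape (1, 861): the BOS token followed by 860 sampled tokens.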
# --- Main Gradio Inference Function ---
def generate_audio_and_plot(genre_name_selected, duration_seconds, temperature, top_p_val):
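    """Gradio callback: sample tokens, decode them to a mel spectrogram with
    the pre-encoder, vocode the mel to a waveform, and return
    (mel plot, (sampling_rate, waveform), status message)."""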
if not MODELS_LOADED or not all([pre_enc, model, vocoder]):
# Create a dummy plot and audio to indicate failure
fig, ax = plt.subplots()
ax.text(0.5, 0.5, "Models not loaded. Cannot generate.", ha='center', va='center')
dummy_audio = np.zeros(SAMPLING_RATE) # 1 second of silence
return fig, (SAMPLING_RATE, dummy_audio), "Error: Models not loaded. Check console."
try:
genre_id_selected = genre_to_id_map.get(genre_name_selected)
if genre_id_selected is None:
# Fallback if genre name not in map (should not happen with dropdown)
return None, None, f"Error: Genre '{genre_name_selected}' not found."
n_tokens = int(TOKENS_PER_SECOND * duration_seconds)
print(
f"Generating for Genre: {genre_name_selected} (ID: {genre_id_selected}), Duration: {duration_seconds}s, Temp: {temperature}, Top-P: {top_p_val}")
indices = generate_music_tokens(
model, BOS_ID, genre_id_selected, n_tokens, temperature, top_p_val, device
) # (1, T_tokens)
mel = pre_enc.decode(indices) # Expected: (1, T_mel_frames, Mel_channels)
        # Assuming pre_enc.decode yields (1, T_mel_frames, Mel_channels) and
        # the vocoder expects (1, Mel_channels, T_mel_frames), transpose here.
        mel_for_vocoder = mel.transpose(1, 2)
        # Vocoder inference: vocoder.infer should return a 1-D waveform as a
        # torch tensor or numpy array; both cases are handled below.
        audio_waveform_tensor = vocoder.infer(mel_for_vocoder)
if isinstance(audio_waveform_tensor, torch.Tensor):
audio_waveform_np = audio_waveform_tensor.squeeze().cpu().numpy()
elif isinstance(audio_waveform_tensor, np.ndarray):
audio_waveform_np = audio_waveform_tensor.squeeze()
else:
raise TypeError(f"Vocoder output type not recognized: {type(audio_waveform_tensor)}")
# Ensure it's float32 for audio component, and normalize if necessary
if audio_waveform_np.dtype != np.float32:
audio_waveform_np = audio_waveform_np.astype(np.float32)
# Optional: Normalize audio to [-1, 1] if not already
# max_val = np.max(np.abs(audio_waveform_np))
# if max_val > 0:
# audio_waveform_np = audio_waveform_np / max_val
# Create Mel Spectrogram Plot
mel_cpu = mel.squeeze(0).detach().cpu().numpy() # (T_mel_frames, Mel_channels)
fig = plt.figure(figsize=(10, 4))
plt.imshow(mel_cpu.T, origin='lower', aspect='auto', interpolation='nearest')
plt.xlabel('Time Frame')
plt.ylabel('Mel Channel')
plt.title(f'Mel Spectrogram: {genre_name_selected}, {duration_seconds}s')
plt.colorbar(label='Amplitude')
plt.tight_layout()
# plt.close(fig) # Close the figure to free memory if not returning fig object directly
status_message = f"Generated {duration_seconds}s of {genre_name_selected} music."
print(status_message)
return fig, (SAMPLING_RATE, audio_waveform_np), status_message
except Exception as e:
print(f"Error during generation: {e}")
import traceback
traceback.print_exc()
# Return empty plot/audio on error
fig, ax = plt.subplots()
ax.text(0.5, 0.5, f"Error: {e}", ha='center', va='center', wrap=True)
empty_audio = np.zeros(SAMPLING_RATE)
return fig, (SAMPLING_RATE, empty_audio), f"Error: {e}"
# --- Gradio UI Definition ---
iface = gr.Interface(
fn=generate_audio_and_plot,
inputs=[
gr.Dropdown(choices=genre_name_list, value=genre_name_list[0] if genre_name_list else None, label="Genre"),
gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Duration (seconds)"),
gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.90, label="Temperature"), # User example used 0.95
gr.Slider(minimum=0.5, maximum=1.0, step=0.01, value=0.95, label="Top-P Sampling") # User example used 0.95
],
outputs=[
gr.Plot(label="Mel Spectrogram"),
gr.Audio(label="Generated Audio", type="numpy"), # type="numpy" expects (sr, data)
gr.Textbox(label="Status")
],
title="MusicLSTM + MQGAN Music Generator sample",
description="Genre-conditioned music generation with MusicLSTM. Select a genre, duration, and sampling parameters. Note that outputs may be nonsense depending on genre and luck.",
allow_flagging="never"
)
if __name__ == "__main__":
if not MODELS_LOADED:
print("\n--- WARNING: MODELS DID NOT LOAD ---")
print("The Gradio app will launch, but generation will fail.")
print("Please check model paths, MQGAN class definitions/imports, and any errors above.")
# Create a dummy genre_ids.csv if it doesn't exist for the app to run
if not os.path.exists(GENRE_CSV_PATH):
print(f"Creating a dummy '{GENRE_CSV_PATH}' for demonstration.")
with open(GENRE_CSV_PATH, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["genre_id", "genre_name"])
writer.writerow([0, "60s"])
writer.writerow([176, "Rock"]) # From user's example
writer.writerow([48, "Electronic"]) # From user's example
writer.writerow([10, "Ambient"])
writer.writerow([20, "Classical"])
        # Reload genres now that the dummy CSV exists. Note: iface above was
        # already built with the original genre list, so a freshly created CSV
        # only populates the dropdown on the next run; making it immediate
        # would require rebuilding the interface or a dynamic gr.update.
        genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)
iface.launch(share=True)