import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import csv
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# --- Configuration & Constants ---
MODEL_BASE_PATH = "MQGAN/models"
PREENC_PATH = os.path.join(MODEL_BASE_PATH, "preenc_musicv3.pth")
MUSICLSTM_PATH = os.path.join(MODEL_BASE_PATH, "musiclstm_1.pt")
ISTFTNET_PATH = os.path.join(MODEL_BASE_PATH, "istftnet_music_ft")
GENRE_CSV_PATH = os.path.join(MODEL_BASE_PATH, "genre_ids.csv")

VOCAB_SIZE = 1010
NUM_GENRES = 250  # Fixed by the checkpoint; genre_ids.csv may list fewer genres
BOS_ID = 1001
PAD_ID = 1002
TOKENS_PER_SECOND = 86
SAMPLING_RATE = 44100

# Use the Agg backend for Matplotlib in headless environments (like Gradio servers)
matplotlib.use('Agg')

# --- MQGAN module imports ---
from MQGAN import get_pre_encoder, ISTFTNetFE, MusicLSTM

# --- Global Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Warn early if expected model files are missing
if not os.path.exists(PREENC_PATH):
    print(f"Warning: Pre-encoder model not found at {PREENC_PATH}")
if not os.path.exists(MUSICLSTM_PATH):
    print(f"Warning: MusicLSTM model not found at {MUSICLSTM_PATH}")
if not (os.path.exists(ISTFTNET_PATH + ".ts") or os.path.exists(ISTFTNET_PATH + ".pth")):
    print(f"Warning: Vocoder model not found at {ISTFTNET_PATH} (.ts or .pth)")

pre_enc = model = vocoder = None
MODELS_LOADED = False
try:
    pre_enc = get_pre_encoder(PREENC_PATH, device,
                              channels=[192, 768, 1024, 1024],
                              kernel_sizes=[3, 5, 7, 11],
                              mel_channels=160)

    # The genre count must match what the checkpoint was trained with, so
    # NUM_GENRES stays fixed rather than being derived from genre_ids.csv.
    model = MusicLSTM(vocab_size=VOCAB_SIZE, num_genres=NUM_GENRES, pad_id=PAD_ID)
    chkp = torch.load(MUSICLSTM_PATH, map_location=device, weights_only=False)
    model.load_state_dict(chkp["model_state_dict"])
    model = model.to(device).eval()

    vocoder = ISTFTNetFE(None, None)  # adjust constructor args if the class requires them
    vocoder.load_ts(ISTFTNET_PATH, device)  # load_ts may expect a directory/prefix
    vocoder = vocoder.to(device).eval()

    MODELS_LOADED = True
    print("Models loaded successfully.")
except Exception as e:
    print(f"Error loading models: {e}. Generation will be disabled.")
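# Optional smoke test (a sketch, disabled by default): push a short dummy
# token sequence through the same pre_enc.decode -> vocoder.infer path the
# app uses below. This assumes code index 0 is a valid codebook entry and
# the tensor shapes noted in generate_audio_and_plot; uncomment to verify
# the stack right after loading.
# if MODELS_LOADED:
#     with torch.no_grad():
#         _dummy = torch.zeros(1, TOKENS_PER_SECOND, dtype=torch.long, device=device)
#         _mel = pre_enc.decode(_dummy)               # (1, T_mel_frames, mel_channels)
#         _wav = vocoder.infer(_mel.transpose(1, 2))  # waveform tensor or ndarray
#         print(f"Smoke test OK, waveform shape: {tuple(_wav.shape)}")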
# --- Genre Loading ---
def load_genres(csv_filepath):
    """Read (genre_id, genre_name) rows from a two-column CSV with a header."""
    genre_map = {}
    genre_names = []
    default_genres = {"Rock": 176, "Pop": 100, "Electronic": 50}  # fallback set
    if not os.path.exists(csv_filepath):
        print(f"Warning: Genre CSV file not found at {csv_filepath}. Using default genres.")
        return dict(default_genres), list(default_genres.keys())
    try:
        with open(csv_filepath, mode='r', encoding='utf-8') as infile:
            reader = csv.reader(infile)
            next(reader)  # skip header row
            for row in reader:
                if len(row) == 2:
                    genre_id, genre_name = int(row[0]), row[1]
                    genre_map[genre_name] = genre_id
                    genre_names.append(genre_name)
        # NUM_GENRES is deliberately not updated here: the model's genre
        # embedding has a fixed size from training.
        print(f"Loaded {len(genre_names)} genres from {csv_filepath}")
    except Exception as e:
        print(f"Error loading genres from {csv_filepath}: {e}. Using default genres.")
        return dict(default_genres), list(default_genres.keys())
    return genre_map, genre_names


genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)
if not genre_name_list:  # fallback if the CSV exists but yields no rows
    genre_name_list = ["Rock"]
    genre_to_id_map = {"Rock": 176}


# --- Helper Functions ---
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9,
                    filter_value: float = -float('Inf')) -> torch.Tensor:
    """Nucleus (top-p) filtering for logits of shape (batch, vocab_size)."""
    if top_p is None or top_p >= 1.0:
        return logits
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)  # treat a bare distribution as a batch of one
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mask tokens once the cumulative probability exceeds top_p. The shift
    # keeps the first token that crosses the threshold, so the retained set
    # always has cumulative mass >= top_p and is never empty.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Scatter the mask from sorted order back to vocabulary order; this
    # handles any batch size, unlike per-item index assignment.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, filter_value)
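# Worked example of the nucleus cutoff above (illustrative numbers, not from
# the model): with sorted probabilities [0.60, 0.25, 0.10, 0.05] and
# top_p = 0.7, the cumulative sums are [0.60, 0.85, 0.95, 1.00]. The raw
# mask (cum > 0.7) is [F, T, T, T]; the one-position shift keeps the first
# token that crosses the threshold, giving [F, F, T, T], so sampling is
# restricted to the 0.60 and 0.25 tokens (renormalized by the softmax in
# generate_music_tokens below).
#   demo_logits = torch.log(torch.tensor([[0.60, 0.25, 0.10, 0.05]]))
#   top_p_filtering(demo_logits, top_p=0.7)  # last two entries become -inf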
@torch.no_grad()
def generate_music_tokens(
    music_model,                      # named to avoid shadowing the global 'model'
    initial_bos: int,
    genre_id: int,
    n_tokens_to_generate: int,
    temperature: float = 0.75,
    top_p: float = 0.9,
    current_device: torch.device = None,  # named to avoid shadowing the global 'device'
) -> torch.LongTensor:
    if current_device is None:
        current_device = next(music_model.parameters()).device
    music_model.eval()

    num_layers = music_model.lstm.num_layers
    hidden_size = music_model.lstm.hidden_size
    h = torch.zeros(num_layers, 1, hidden_size, device=current_device)
    c = torch.zeros(num_layers, 1, hidden_size, device=current_device)

    # Validate genre_id against the embedding table before indexing it
    if not (0 <= genre_id < music_model.genre_emb.num_embeddings):
        raise ValueError(f"genre_id {genre_id} is out of bounds for genre_emb layer "
                         f"with {music_model.genre_emb.num_embeddings} embeddings.")
    genre_tensor = torch.tensor([genre_id], device=current_device, dtype=torch.long)
    genre_emb = music_model.genre_emb(genre_tensor)

    token = initial_bos
    generated = [token]
    for _ in range(n_tokens_to_generate):
        tok_e = music_model.tok_emb(torch.tensor([token], device=current_device))
        inp = (tok_e + genre_emb).unsqueeze(1)  # (1, 1, emb_dim): one step, one sequence
        out, (h, c) = music_model.lstm(inp, (h, c))
        out = music_model.dropout(out.squeeze(1))  # (1, hidden_size)
        logits = music_model.proj(out) / max(temperature, 1e-8)  # (1, vocab_size)
        filtered_logits = top_p_filtering(logits, top_p=top_p)
        probs = F.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        token = next_token

    return torch.tensor([generated], device=current_device, dtype=torch.long)
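# Example standalone call (mirrors how generate_audio_and_plot invokes this
# below; assumes "Rock" is present in genre_to_id_map):
#   ids = generate_music_tokens(model, BOS_ID, genre_to_id_map["Rock"],
#                               n_tokens_to_generate=5 * TOKENS_PER_SECOND,
#                               temperature=0.9, top_p=0.95,
#                               current_device=device)
#   ids.shape  # (1, 1 + 5 * TOKENS_PER_SECOND): leading BOS plus 5 seconds of tokens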
# --- Main Gradio Inference Function ---
def generate_audio_and_plot(genre_name_selected, duration_seconds, temperature, top_p_val):
    if not MODELS_LOADED or not all([pre_enc, model, vocoder]):
        # Return a placeholder plot and silence so the UI fails gracefully
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "Models not loaded. Cannot generate.", ha='center', va='center')
        dummy_audio = np.zeros(SAMPLING_RATE, dtype=np.float32)  # 1 second of silence
        return fig, (SAMPLING_RATE, dummy_audio), "Error: Models not loaded. Check console."

    try:
        genre_id_selected = genre_to_id_map.get(genre_name_selected)
        if genre_id_selected is None:
            # Should not happen with a dropdown, but guard anyway
            return None, None, f"Error: Genre '{genre_name_selected}' not found."

        n_tokens = int(TOKENS_PER_SECOND * duration_seconds)
        print(f"Generating for Genre: {genre_name_selected} (ID: {genre_id_selected}), "
              f"Duration: {duration_seconds}s, Temp: {temperature}, Top-P: {top_p_val}")

        indices = generate_music_tokens(
            model, BOS_ID, genre_id_selected, n_tokens, temperature, top_p_val, device
        )  # (1, T_tokens)

        mel = pre_enc.decode(indices)  # expected (1, T_mel_frames, mel_channels)
        # The vocoder expects (1, mel_channels, T_mel_frames), so swap the last two axes
        mel_for_vocoder = mel.transpose(1, 2)

        # Vocoder inference: the output should be a waveform as a 1D tensor or ndarray
        audio_waveform_tensor = vocoder.infer(mel_for_vocoder)
        if isinstance(audio_waveform_tensor, torch.Tensor):
            audio_waveform_np = audio_waveform_tensor.squeeze().cpu().numpy()
        elif isinstance(audio_waveform_tensor, np.ndarray):
            audio_waveform_np = audio_waveform_tensor.squeeze()
        else:
            raise TypeError(f"Vocoder output type not recognized: {type(audio_waveform_tensor)}")

        # gr.Audio with type="numpy" expects float32 (or integer PCM) samples
        if audio_waveform_np.dtype != np.float32:
            audio_waveform_np = audio_waveform_np.astype(np.float32)
        # Optional: normalize audio to [-1, 1] if the vocoder does not already
        # max_val = np.max(np.abs(audio_waveform_np))
        # if max_val > 0:
        #     audio_waveform_np = audio_waveform_np / max_val

        # Create the mel spectrogram plot
        mel_cpu = mel.squeeze(0).detach().cpu().numpy()  # (T_mel_frames, mel_channels)
        fig = plt.figure(figsize=(10, 4))
        plt.imshow(mel_cpu.T, origin='lower', aspect='auto', interpolation='nearest')
        plt.xlabel('Time Frame')
        plt.ylabel('Mel Channel')
        plt.title(f'Mel Spectrogram: {genre_name_selected}, {duration_seconds}s')
        plt.colorbar(label='Amplitude')
        plt.tight_layout()
        # The figure object is returned to gr.Plot, so it is not closed here

        status_message = f"Generated {duration_seconds}s of {genre_name_selected} music."
        print(status_message)
        return fig, (SAMPLING_RATE, audio_waveform_np), status_message

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Return an error plot and silence instead of crashing the UI
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, f"Error: {e}", ha='center', va='center', wrap=True)
        empty_audio = np.zeros(SAMPLING_RATE, dtype=np.float32)
        return fig, (SAMPLING_RATE, empty_audio), f"Error: {e}"
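# Optional helper (a sketch, not wired into the UI): the inference function
# can also be driven without Gradio, and the float32 waveform it returns can
# be written to disk with scipy (float32 WAV data is expected in [-1, 1]):
#   from scipy.io import wavfile
#   fig, (sr, wav), status = generate_audio_and_plot("Rock", 5, 0.9, 0.95)
#   wavfile.write("sample.wav", sr, wav)  # assumes "Rock" exists in genre_to_id_map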
# --- Gradio UI Definition ---
iface = gr.Interface(
    fn=generate_audio_and_plot,
    inputs=[
        gr.Dropdown(choices=genre_name_list,
                    value=genre_name_list[0] if genre_name_list else None,
                    label="Genre"),
        gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Duration (seconds)"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.90, label="Temperature"),
        gr.Slider(minimum=0.5, maximum=1.0, step=0.01, value=0.95, label="Top-P Sampling"),
    ],
    outputs=[
        gr.Plot(label="Mel Spectrogram"),
        gr.Audio(label="Generated Audio", type="numpy"),  # type="numpy" expects (sr, data)
        gr.Textbox(label="Status"),
    ],
    title="MusicLSTM + MQGAN Music Generator sample",
    description=("Genre-conditioned music generation with MusicLSTM. Select a genre, "
                 "duration, and sampling parameters. Note that outputs may be nonsense "
                 "depending on genre and luck."),
    allow_flagging="never",
)

if __name__ == "__main__":
    if not MODELS_LOADED:
        print("\n--- WARNING: MODELS DID NOT LOAD ---")
        print("The Gradio app will launch, but generation will fail.")
        print("Please check model paths, MQGAN class definitions/imports, and any errors above.")

    # Create a dummy genre_ids.csv if it doesn't exist so the app can run
    if not os.path.exists(GENRE_CSV_PATH):
        print(f"Creating a dummy '{GENRE_CSV_PATH}' for demonstration.")
        os.makedirs(os.path.dirname(GENRE_CSV_PATH), exist_ok=True)
        with open(GENRE_CSV_PATH, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["genre_id", "genre_name"])
            writer.writerow([0, "60s"])
            writer.writerow([176, "Rock"])
            writer.writerow([48, "Electronic"])
            writer.writerow([10, "Ambient"])
            writer.writerow([20, "Classical"])
        # Reload genres now that the file exists. The Dropdown above was built
        # from the earlier (fallback) genre list, and since iface is already
        # defined, the new choices only take effect on the next run; defining
        # the interface after this block, or updating choices dynamically with
        # gr.update, would make the change immediate.
        genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)

    iface.launch(share=True)
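# To serve on a LAN or inside a container instead of creating a public share
# link, the standard Gradio launch parameters can be used:
#   iface.launch(server_name="0.0.0.0", server_port=7860)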