import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend for headless environments (e.g. Gradio/Spaces); set before importing pyplot
import matplotlib.pyplot as plt
import csv
import os
import sys

sys.path.append(os.path.dirname(os.path.abspath(__file__)))
# --- Configuration & Constants ---
MODEL_BASE_PATH = "MQGAN/models"
PREENC_PATH = os.path.join(MODEL_BASE_PATH, "preenc_musicv3.pth")
MUSICLSTM_PATH = os.path.join(MODEL_BASE_PATH, "musiclstm_1.pt")
ISTFTNET_PATH = os.path.join(MODEL_BASE_PATH, "istftnet_music_ft")
GENRE_CSV_PATH = os.path.join(MODEL_BASE_PATH, "genre_ids.csv")

VOCAB_SIZE = 1010
NUM_GENRES = 250  # Fixed at training time; could also be derived from genre_ids.csv
BOS_ID = 1001
PAD_ID = 1002
TOKENS_PER_SECOND = 86
SAMPLING_RATE = 44100
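
# Rough sizing: at TOKENS_PER_SECOND = 86, a 10-second clip corresponds to
# int(86 * 10) = 860 generated tokens (this is the duration -> n_tokens
# conversion used in generate_audio_and_plot below).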
# --- MQGAN module imports ---
from MQGAN import get_pre_encoder
from MQGAN import ISTFTNetFE
from MQGAN import MusicLSTM
# --- Global Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Check that the model files exist before loading
if not os.path.exists(PREENC_PATH):
    print(f"Warning: Pre-encoder model not found at {PREENC_PATH}")
if not os.path.exists(MUSICLSTM_PATH):
    print(f"Warning: MusicLSTM model not found at {MUSICLSTM_PATH}")
if not (os.path.exists(ISTFTNET_PATH + ".ts") or os.path.exists(ISTFTNET_PATH + ".pth")):
    print(f"Warning: Vocoder model not found at {ISTFTNET_PATH} (.ts or .pth)")
pre_enc = get_pre_encoder(PREENC_PATH, device, channels=[192, 768, 1024, 1024],
                          kernel_sizes=[3, 5, 7, 11], mel_channels=160)

model = MusicLSTM(vocab_size=VOCAB_SIZE, num_genres=NUM_GENRES, pad_id=PAD_ID)
chkp = torch.load(MUSICLSTM_PATH, map_location=device, weights_only=False)
model.load_state_dict(chkp["model_state_dict"])
model = model.to(device).eval()

vocoder = ISTFTNetFE(None, None)  # Adjust constructor args if needed
vocoder.load_ts(ISTFTNET_PATH, device)  # load_ts may expect a directory/path prefix
vocoder = vocoder.to(device).eval()

MODELS_LOADED = True
print("Models loaded successfully.")
# --- Genre Loading ---
def load_genres(csv_filepath):
    """Load genre_name -> genre_id mappings from a two-column CSV (genre_id, genre_name)."""
    genre_map = {}
    genre_names = []
    default_genres = {"Rock": 176, "Pop": 100, "Electronic": 50}  # Fallback examples
    if not os.path.exists(csv_filepath):
        print(f"Warning: Genre CSV file not found at {csv_filepath}. Using default genres.")
        return default_genres, list(default_genres.keys())
    try:
        with open(csv_filepath, mode='r', encoding='utf-8') as infile:
            reader = csv.reader(infile)
            next(reader)  # Skip header
            for row in reader:
                if len(row) == 2:
                    genre_id, genre_name = int(row[0]), row[1]
                    genre_map[genre_name] = genre_id
                    genre_names.append(genre_name)
        # Note: NUM_GENRES stays fixed; the model's genre embedding size was set at
        # training time, so we do not resize it to len(genre_map) here.
        print(f"Loaded {len(genre_names)} genres from {csv_filepath}")
    except Exception as e:
        print(f"Error loading genres from {csv_filepath}: {e}. Using default genres.")
        return default_genres, list(default_genres.keys())
    return genre_map, genre_names
genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)
if not genre_name_list:  # Fallback if loading fails completely
    genre_name_list = ["Rock"]
    genre_to_id_map = {"Rock": 176}
# --- Helper Functions ---
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9, filter_value: float = -float('Inf')) -> torch.Tensor:
    """Nucleus (top-p) filtering: mask out tokens outside the smallest set whose
    cumulative probability exceeds top_p. Works on (vocab,) or (batch, vocab) logits."""
    if top_p is None or top_p >= 1.0:
        return logits
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    # Mark tokens past the nucleus; shift the mask right so the first token that
    # crosses the threshold is kept, and always keep the most probable token.
    sorted_indices_to_remove = cumulative_probs > top_p
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    sorted_indices_to_remove[..., 0] = False
    # Map the mask from sorted order back to vocabulary order, then apply it.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    return logits.masked_fill(indices_to_remove, filter_value)
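
# Illustrative sanity check (hypothetical toy values, not executed at import time):
#   logits = torch.tensor([[2.0, 1.0, 0.0, -1.0]])  # softmax probs ~ [0.64, 0.24, 0.09, 0.03]
#   top_p_filtering(logits, top_p=0.8)
#   -> keeps the first two tokens (cumulative prob crosses 0.8 at the second)
#      and masks the rest to -inf, so multinomial sampling never picks them.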
def generate_music_tokens(
        music_model,  # Named to avoid shadowing the global 'model'
        initial_bos: int,
        genre_id: int,
        n_tokens_to_generate: int,
        temperature: float = 0.75,
        top_p: float = 0.9,
        current_device: torch.device = None,  # Named to avoid shadowing the global 'device'
) -> torch.LongTensor:
    if current_device is None:
        current_device = next(music_model.parameters()).device
    music_model.eval()

    num_layers = music_model.lstm.num_layers
    hidden_size = music_model.lstm.hidden_size
    h = torch.zeros(num_layers, 1, hidden_size, device=current_device)
    c = torch.zeros(num_layers, 1, hidden_size, device=current_device)

    genre_tensor = torch.tensor([genre_id], device=current_device, dtype=torch.long)
    # Validate genre_id against the embedding table before indexing into it
    if not (0 <= genre_id < music_model.genre_emb.num_embeddings):
        raise ValueError(f"genre_id {genre_id} is out of bounds for genre_emb layer "
                         f"with {music_model.genre_emb.num_embeddings} embeddings.")
    genre_emb = music_model.genre_emb(genre_tensor)

    token = initial_bos
    generated = [token]
    for _ in range(n_tokens_to_generate):
        tok_e = music_model.tok_emb(torch.tensor([token], device=current_device))
        inp = (tok_e + genre_emb).unsqueeze(1)  # (1, 1, emb_dim): one step, batch size 1
        out, (h, c) = music_model.lstm(inp, (h, c))
        out = music_model.dropout(out.squeeze(1))  # (1, hidden_size)
        logits = music_model.proj(out) / max(temperature, 1e-8)  # (1, vocab_size)
        filtered_logits = top_p_filtering(logits, top_p=top_p)
        probs = F.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        token = next_token
    return torch.tensor([generated], device=current_device, dtype=torch.long)
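
# Example usage (hypothetical genre id; mirrors how generate_audio_and_plot calls it below):
#   tokens = generate_music_tokens(model, BOS_ID, genre_id=176,
#                                  n_tokens_to_generate=5 * TOKENS_PER_SECOND)
#   tokens.shape -> (1, 1 + 5 * TOKENS_PER_SECOND), including the leading BOS token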
# --- Main Gradio Inference Function ---
def generate_audio_and_plot(genre_name_selected, duration_seconds, temperature, top_p_val):
    if not MODELS_LOADED or not all([pre_enc, model, vocoder]):
        # Return a dummy plot and one second of silence to signal failure
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "Models not loaded. Cannot generate.", ha='center', va='center')
        dummy_audio = np.zeros(SAMPLING_RATE)
        return fig, (SAMPLING_RATE, dummy_audio), "Error: Models not loaded. Check console."

    try:
        genre_id_selected = genre_to_id_map.get(genre_name_selected)
        if genre_id_selected is None:
            # Should not happen with a dropdown, but guard anyway
            return None, None, f"Error: Genre '{genre_name_selected}' not found."

        n_tokens = int(TOKENS_PER_SECOND * duration_seconds)
        print(f"Generating for Genre: {genre_name_selected} (ID: {genre_id_selected}), "
              f"Duration: {duration_seconds}s, Temp: {temperature}, Top-P: {top_p_val}")

        indices = generate_music_tokens(
            model, BOS_ID, genre_id_selected, n_tokens, temperature, top_p_val, device
        )  # (1, T_tokens)

        mel = pre_enc.decode(indices)  # (1, T_mel_frames, Mel_channels)
        # The vocoder expects channels-first input, so transpose to (1, Mel_channels, T_mel_frames)
        mel_for_vocoder = mel.transpose(1, 2)

        # Vocoder inference: expected to return a 1D waveform (torch tensor or numpy array)
        audio_waveform_tensor = vocoder.infer(mel_for_vocoder)
        if isinstance(audio_waveform_tensor, torch.Tensor):
            audio_waveform_np = audio_waveform_tensor.squeeze().cpu().numpy()
        elif isinstance(audio_waveform_tensor, np.ndarray):
            audio_waveform_np = audio_waveform_tensor.squeeze()
        else:
            raise TypeError(f"Vocoder output type not recognized: {type(audio_waveform_tensor)}")

        # Gradio's Audio component expects float32. If clipping occurs, peak-normalize
        # to [-1, 1] first, e.g.:
        #   max_val = np.max(np.abs(audio_waveform_np))
        #   if max_val > 0:
        #       audio_waveform_np = audio_waveform_np / max_val
        if audio_waveform_np.dtype != np.float32:
            audio_waveform_np = audio_waveform_np.astype(np.float32)

        # Mel spectrogram plot
        mel_cpu = mel.squeeze(0).detach().cpu().numpy()  # (T_mel_frames, Mel_channels)
        fig = plt.figure(figsize=(10, 4))
        plt.imshow(mel_cpu.T, origin='lower', aspect='auto', interpolation='nearest')
        plt.xlabel('Time Frame')
        plt.ylabel('Mel Channel')
        plt.title(f'Mel Spectrogram: {genre_name_selected}, {duration_seconds}s')
        plt.colorbar(label='Amplitude')
        plt.tight_layout()

        status_message = f"Generated {duration_seconds}s of {genre_name_selected} music."
        print(status_message)
        return fig, (SAMPLING_RATE, audio_waveform_np), status_message

    except Exception as e:
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        # Return an error plot and silence so the UI still renders something
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, f"Error: {e}", ha='center', va='center', wrap=True)
        empty_audio = np.zeros(SAMPLING_RATE)
        return fig, (SAMPLING_RATE, empty_audio), f"Error: {e}"
# --- Gradio UI Definition ---
iface = gr.Interface(
    fn=generate_audio_and_plot,
    inputs=[
        gr.Dropdown(choices=genre_name_list, value=genre_name_list[0] if genre_name_list else None, label="Genre"),
        gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Duration (seconds)"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.90, label="Temperature"),
        gr.Slider(minimum=0.5, maximum=1.0, step=0.01, value=0.95, label="Top-P Sampling")
    ],
    outputs=[
        gr.Plot(label="Mel Spectrogram"),
        gr.Audio(label="Generated Audio", type="numpy"),  # type="numpy" takes (sample_rate, data) tuples
        gr.Textbox(label="Status")
    ],
    title="MusicLSTM + MQGAN Music Generator sample",
    description="Genre-conditioned music generation with MusicLSTM. Select a genre, duration, and sampling parameters. Note that outputs may be nonsense depending on genre and luck.",
    allow_flagging="never"
)
if __name__ == "__main__":
    if not MODELS_LOADED:
        print("\n--- WARNING: MODELS DID NOT LOAD ---")
        print("The Gradio app will launch, but generation will fail.")
        print("Please check model paths, MQGAN class definitions/imports, and any errors above.")

    # Create a dummy genre_ids.csv if it doesn't exist so the app can run
    if not os.path.exists(GENRE_CSV_PATH):
        print(f"Creating a dummy '{GENRE_CSV_PATH}' for demonstration.")
        with open(GENRE_CSV_PATH, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["genre_id", "genre_name"])
            writer.writerow([0, "60s"])
            writer.writerow([176, "Rock"])
            writer.writerow([48, "Electronic"])
            writer.writerow([10, "Ambient"])
            writer.writerow([20, "Classical"])
        # Reload genres from the file we just created. Note that iface is already
        # defined at this point, so its Dropdown keeps the earlier choices; the new
        # list only takes effect on the next run. Making it immediate would require
        # deferring the interface definition or using dynamic gr.update calls.
        genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)

    iface.launch(share=True)