# Gradio demo app: genre-conditioned music generation with MusicLSTM and MQGAN components.
import gradio as gr
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import csv
import os
import sys

# Put this script's own directory on the import path.
# NOTE(review): presumably so the `MQGAN` package imported below resolves when
# the app is started from another working directory - confirm the repo layout.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))


# --- Configuration & Constants ---
# Checkpoint and metadata locations, all relative to the working directory.
MODEL_BASE_PATH = "MQGAN/models"
PREENC_PATH = os.path.join(MODEL_BASE_PATH, "preenc_musicv3.pth")
MUSICLSTM_PATH = os.path.join(MODEL_BASE_PATH, "musiclstm_1.pt")
# No file extension here: the existence check further down probes ".ts" and ".pth".
ISTFTNET_PATH = os.path.join(MODEL_BASE_PATH, "istftnet_music_ft")
GENRE_CSV_PATH = os.path.join(MODEL_BASE_PATH, "genre_ids.csv")

# Hyper-parameters handed to the MusicLSTM constructor; they must match the
# checkpoint being loaded.
VOCAB_SIZE = 1010
NUM_GENRES = 250  # Fixed value; it is NOT derived from genre_ids.csv anywhere in this file
BOS_ID = 1001  # First token fed to the sampler
PAD_ID = 1002  # Passed to MusicLSTM as pad_id
TOKENS_PER_SECOND = 86  # Converts the requested duration into a token count
SAMPLING_RATE = 44100  # Sample rate reported to the Gradio audio component

# Use Agg backend for Matplotlib in headless environments (like Gradio)
matplotlib.use('Agg')

# --- MQGAN imports ---
# These are plain (unguarded) imports: if the MQGAN package is missing, the
# script stops here with ImportError.
from MQGAN import get_pre_encoder as actual_get_pre_encoder
from MQGAN import ISTFTNetFE as actual_ISTFTNetFE
from MQGAN import MusicLSTM as actual_MusicLSTM

# Bind the imported objects to the short names used by the rest of the file.
# No placeholder implementations are defined in this file.
get_pre_encoder = actual_get_pre_encoder
ISTFTNetFE = actual_ISTFTNetFE
MusicLSTM = actual_MusicLSTM

# --- Global Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Warn early about missing checkpoints so the cause of a later load failure is
# obvious from the console output.
if not os.path.exists(PREENC_PATH):
    print(f"Warning: Pre-encoder model not found at {PREENC_PATH}")
if not os.path.exists(MUSICLSTM_PATH):
    print(f"Warning: MusicLSTM model not found at {MUSICLSTM_PATH}")
if not (os.path.exists(ISTFTNET_PATH + ".ts") or os.path.exists(ISTFTNET_PATH + ".pth")):  # .ts or .pth for vocoder
    print(f"Warning: Vocoder model not found at {ISTFTNET_PATH}. (.ts or .pth)")

# These module-level names are read by generate_audio_and_plot() and by the
# __main__ block. They keep these "not loaded" values if anything below fails,
# so the UI can report the problem instead of the script dying at import time.
pre_enc = None
model = None
vocoder = None
MODELS_LOADED = False

try:
    pre_enc = get_pre_encoder(PREENC_PATH, device, channels=[192, 768, 1024, 1024], kernel_sizes=[3, 5, 7, 11],
                              mel_channels=160)

    # NUM_GENRES is fixed; the checkpoint must have been trained with the same value.
    model = MusicLSTM(vocab_size=VOCAB_SIZE, num_genres=NUM_GENRES, pad_id=PAD_ID)

    # SECURITY NOTE: weights_only=False lets torch.load unpickle arbitrary
    # Python objects, which can execute code. Only load checkpoints from a
    # trusted source, or switch to weights_only=True if the checkpoint allows it.
    chkp = torch.load(MUSICLSTM_PATH, map_location=device, weights_only=False)
    model.load_state_dict(chkp["model_state_dict"])
    model = model.to(device).eval()

    vocoder = ISTFTNetFE(None, None)  # Adjust constructor if needed
    vocoder.load_ts(ISTFTNET_PATH, device)  # load_ts might expect a directory/prefix
    vocoder = vocoder.to(device).eval()

    MODELS_LOADED = True
    print("Models loaded successfully.")
except Exception as load_error:  # top-level boundary: report and keep the app importable
    import traceback

    MODELS_LOADED = False
    print(f"Error: failed to load models: {load_error}")
    traceback.print_exc()




# --- Genre Loading ---
def load_genres(csv_filepath):
    """Load the genre-name -> genre-id table from a CSV file.

    The file is expected to start with a header row, followed by
    ``genre_id,genre_name`` rows. Rows that do not have exactly two columns
    are ignored, and rows whose id is not an integer are skipped with a
    warning instead of invalidating the whole file. When a name occurs more
    than once the last id wins and the name is listed only once.

    A built-in default table is returned when the file is missing, cannot be
    read, or contains no usable rows.

    Args:
        csv_filepath: Path of the CSV file to read.

    Returns:
        A ``(genre_map, genre_names)`` tuple: a dict mapping genre name to
        integer id, and the list of unique names in first-seen order.
    """
    default_genres = {"Rock": 176, "Pop": 100, "Electronic": 50}  # Example

    def _defaults():
        # Fresh copies so callers can never mutate a shared table.
        return dict(default_genres), list(default_genres)

    if not os.path.exists(csv_filepath):
        print(f"Warning: Genre CSV file not found at {csv_filepath}. Using default genres.")
        return _defaults()

    genre_map = {}
    try:
        with open(csv_filepath, mode='r', encoding='utf-8', newline='') as infile:
            reader = csv.reader(infile)
            next(reader, None)  # Skip header; a completely empty file is fine
            for line_no, row in enumerate(reader, start=2):
                if len(row) != 2:
                    continue
                try:
                    genre_id = int(row[0])
                except ValueError:
                    print(f"Warning: Skipping line {line_no} of {csv_filepath}: invalid genre id {row[0]!r}.")
                    continue
                # dict insertion keeps first-seen order and de-duplicates names.
                genre_map[row[1]] = genre_id
    except (OSError, UnicodeError, csv.Error) as e:
        print(f"Error loading genres from {csv_filepath}: {e}. Using default genres.")
        return _defaults()

    if not genre_map:
        print(f"Warning: No genres found in {csv_filepath}. Using default genres.")
        return _defaults()

    genre_names = list(genre_map)
    print(f"Loaded {len(genre_names)} genres from {csv_filepath}")
    return genre_map, genre_names


# Build the genre lookup once at import time; both the dropdown choices and the
# name -> id resolution in the inference function read these two globals.
genre_to_id_map, genre_name_list = load_genres(GENRE_CSV_PATH)
if len(genre_name_list) == 0:
    # Last-resort fallback so the UI always has one selectable genre.
    genre_to_id_map, genre_name_list = {"Rock": 176}, ["Rock"]


# --- Helper Functions (from user code) ---
def top_p_filtering(logits: torch.Tensor, top_p: float = 0.9, filter_value: float = -float('Inf')) -> torch.Tensor:
    """Apply nucleus (top-p) filtering along the last dimension of ``logits``.

    For every distribution in ``logits`` the tokens are ranked by probability
    and the smallest prefix whose cumulative probability exceeds ``top_p`` is
    kept; all other positions are set to ``filter_value``. The most likely
    token is always kept.

    Args:
        logits: Unnormalised scores of shape ``(vocab,)`` or
            ``(batch, vocab)``; any number of leading dimensions works.
        top_p: Cumulative probability threshold. ``None`` or a value
            ``>= 1.0`` disables filtering.
        filter_value: Value written to the removed positions.

    Returns:
        A new tensor with the same shape as ``logits`` (the input is not
        modified). When filtering is disabled the input tensor itself is
        returned.
    """
    if top_p is None or top_p >= 1.0:
        return logits

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Mark everything past the threshold, then shift the mask one step to the
    # right so the token that crosses the threshold is itself kept and rank 0
    # can never be removed.
    over_threshold = cumulative_probs > top_p
    sorted_remove = torch.cat(
        [torch.zeros_like(over_threshold[..., :1]), over_threshold[..., :-1]],
        dim=-1,
    )

    # Map the mask from sorted order back to vocabulary order. Doing this with
    # scatter handles every batch row at once and preserves the input's rank.
    remove = torch.zeros_like(sorted_remove).scatter(-1, sorted_indices, sorted_remove)
    return logits.masked_fill(remove, filter_value)


@torch.no_grad()
def generate_music_tokens(
        music_model,
        initial_bos: int,
        genre_id: int,
        n_tokens_to_generate: int,
        temperature: float = 0.75,
        top_p: float = 0.9,
        current_device: torch.device = None,
) -> torch.LongTensor:
    """Sample a token sequence from ``music_model`` one step at a time.

    The model is switched to eval mode, its LSTM state starts at zero, and
    every step feeds the sum of the previous token's embedding and the genre
    embedding through ``lstm`` -> ``dropout`` -> ``proj``. The resulting
    logits are temperature-scaled, passed through ``top_p_filtering`` and
    sampled with ``torch.multinomial``.

    Args:
        music_model: Object exposing ``lstm``, ``tok_emb``, ``genre_emb``,
            ``dropout`` and ``proj`` (as accessed below).
        initial_bos: Token id the sequence starts with.
        genre_id: Index into ``music_model.genre_emb``.
        n_tokens_to_generate: Number of tokens sampled after ``initial_bos``.
        temperature: Logit divisor; clamped below at ``1e-8``.
        top_p: Nucleus threshold forwarded to ``top_p_filtering``.
        current_device: Device for created tensors; defaults to the device of
            the model's first parameter.

    Returns:
        Long tensor of shape ``(1, n_tokens_to_generate + 1)`` that begins
        with ``initial_bos``.

    Raises:
        ValueError: If ``genre_id`` is outside the genre embedding table.
    """
    if current_device is None:
        current_device = next(music_model.parameters()).device
    music_model.eval()

    # Zero (hidden, cell) state for a single sequence.
    lstm = music_model.lstm
    state_shape = (lstm.num_layers, 1, lstm.hidden_size)
    state = (
        torch.zeros(state_shape, device=current_device),
        torch.zeros(state_shape, device=current_device),
    )

    genre_ids = torch.tensor([genre_id], device=current_device, dtype=torch.long)

    # Fail with a readable message instead of an opaque embedding index error.
    if not (0 <= genre_id < music_model.genre_emb.num_embeddings):
        raise ValueError(f"genre_id {genre_id} is out of bounds for genre_emb layer "
                         f"with {music_model.genre_emb.num_embeddings} embeddings.")

    genre_vec = music_model.genre_emb(genre_ids)

    sequence = [initial_bos]
    for _step in range(n_tokens_to_generate):
        previous = torch.tensor([sequence[-1]], device=current_device)
        # (1, emb_dim) -> (1, 1, emb_dim): a single step of a single sequence.
        step_input = (music_model.tok_emb(previous) + genre_vec).unsqueeze(1)

        lstm_out, state = lstm(step_input, state)
        features = music_model.dropout(lstm_out.squeeze(1))

        scaled_logits = music_model.proj(features) / max(temperature, 1e-8)
        probs = F.softmax(top_p_filtering(scaled_logits, top_p=top_p), dim=-1)

        sequence.append(torch.multinomial(probs, num_samples=1).item())

    return torch.tensor([sequence], device=current_device, dtype=torch.long)


# --- Main Gradio Inference Function ---
def _message_figure(message):
    """Return a Matplotlib figure that shows only *message* (used for error states)."""
    fig, ax = plt.subplots()
    try:
        ax.text(0.5, 0.5, message, ha='center', va='center', wrap=True)
    finally:
        # Detach from pyplot's global registry; the Figure object stays usable
        # for rendering, but no longer accumulates across requests.
        plt.close(fig)
    return fig


def _mel_figure(mel_cpu, title):
    """Return a Matplotlib figure of the mel spectrogram ``mel_cpu`` (frames x channels)."""
    fig, ax = plt.subplots(figsize=(10, 4))
    try:
        image = ax.imshow(mel_cpu.T, origin='lower', aspect='auto', interpolation='nearest')
        ax.set_xlabel('Time Frame')
        ax.set_ylabel('Mel Channel')
        ax.set_title(title)
        fig.colorbar(image, ax=ax, label='Amplitude')
        fig.tight_layout()
    finally:
        # Always release the pyplot reference, even if drawing fails.
        plt.close(fig)
    return fig


# --- Main Gradio Inference Function ---
def generate_audio_and_plot(genre_name_selected, duration_seconds, temperature, top_p_val):
    """Generate audio for one genre and return the values for the three UI outputs.

    Pipeline: sample tokens with ``generate_music_tokens``, decode them to a
    mel spectrogram with ``pre_enc.decode``, and synthesise a waveform with
    ``vocoder.infer``.

    Args:
        genre_name_selected: Genre name; must be a key of ``genre_to_id_map``.
        duration_seconds: Requested length; converted to a token count with
            ``TOKENS_PER_SECOND``.
        temperature: Sampling temperature forwarded to the token sampler.
        top_p_val: Nucleus threshold forwarded to the token sampler.

    Returns:
        ``(figure, (SAMPLING_RATE, waveform), status_message)``. On failure
        the figure shows the error text and the waveform is one second of
        silence. For an unknown genre the first two items are ``None``.
    """
    if not MODELS_LOADED or not all([pre_enc, model, vocoder]):
        dummy_audio = np.zeros(SAMPLING_RATE)  # 1 second of silence
        return (
            _message_figure("Models not loaded. Cannot generate."),
            (SAMPLING_RATE, dummy_audio),
            "Error: Models not loaded. Check console.",
        )

    try:
        genre_id_selected = genre_to_id_map.get(genre_name_selected)
        if genre_id_selected is None:
            # Fallback if genre name not in map (should not happen with dropdown)
            return None, None, f"Error: Genre '{genre_name_selected}' not found."

        n_tokens = int(TOKENS_PER_SECOND * duration_seconds)

        print(
            f"Generating for Genre: {genre_name_selected} (ID: {genre_id_selected}), Duration: {duration_seconds}s, Temp: {temperature}, Top-P: {top_p_val}")

        # Pure inference: disable autograd so no graph is kept for the decoder
        # and vocoder passes.
        with torch.no_grad():
            indices = generate_music_tokens(
                model, BOS_ID, genre_id_selected, n_tokens, temperature, top_p_val, device
            )  # (1, T_tokens)

            # NOTE(review): assumed to be (1, T_mel_frames, Mel_channels) - confirm against pre_enc.
            mel = pre_enc.decode(indices)

            # Swap the last two axes for the vocoder (frames <-> channels).
            mel_for_vocoder = mel.transpose(1, 2)

            audio_waveform_tensor = vocoder.infer(mel_for_vocoder)

        if isinstance(audio_waveform_tensor, torch.Tensor):
            # detach() first: .numpy() raises on tensors that require grad.
            audio_waveform_np = audio_waveform_tensor.detach().squeeze().cpu().numpy()
        elif isinstance(audio_waveform_tensor, np.ndarray):
            audio_waveform_np = audio_waveform_tensor.squeeze()
        else:
            raise TypeError(f"Vocoder output type not recognized: {type(audio_waveform_tensor)}")

        # Hand float32 samples to the audio component.
        if audio_waveform_np.dtype != np.float32:
            audio_waveform_np = audio_waveform_np.astype(np.float32)

        mel_cpu = mel.squeeze(0).detach().cpu().numpy()
        fig = _mel_figure(mel_cpu, f'Mel Spectrogram: {genre_name_selected}, {duration_seconds}s')

        status_message = f"Generated {duration_seconds}s of {genre_name_selected} music."
        print(status_message)

        return fig, (SAMPLING_RATE, audio_waveform_np), status_message

    except Exception as e:  # UI boundary: report the failure instead of crashing the request
        print(f"Error during generation: {e}")
        import traceback
        traceback.print_exc()
        empty_audio = np.zeros(SAMPLING_RATE)
        return _message_figure(f"Error: {e}"), (SAMPLING_RATE, empty_audio), f"Error: {e}"


# --- Gradio UI Definition ---
# The four input components map, in order, to the parameters of
# generate_audio_and_plot; its three return values map to the three outputs.
iface = gr.Interface(
    fn=generate_audio_and_plot,
    inputs=[
        # Choices come from the genre table loaded at import time.
        gr.Dropdown(choices=genre_name_list, value=genre_name_list[0] if genre_name_list else None, label="Genre"),
        gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Duration (seconds)"),
        gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.90, label="Temperature"),
        gr.Slider(minimum=0.5, maximum=1.0, step=0.01, value=0.95, label="Top-P Sampling")
    ],
    outputs=[
        gr.Plot(label="Mel Spectrogram"),
        gr.Audio(label="Generated Audio", type="numpy"),  # receives a (sample_rate, data) tuple
        gr.Textbox(label="Status")
    ],
    title="MusicLSTM + MQGAN Music Generator sample",
    description="Genre-conditioned music generation with MusicLSTM. Select a genre, duration, and sampling parameters. Note that outputs may be nonsense depending on genre and luck.",
    allow_flagging="never"
)

if __name__ == "__main__":
    if not MODELS_LOADED:
        print("\n--- WARNING: MODELS DID NOT LOAD ---")
        print("The Gradio app will launch, but generation will fail.")
        print("Please check model paths, MQGAN class definitions/imports, and any errors above.")

    # Seed a small example genre table when none exists, so the next start of
    # the app has a CSV to read.
    if not os.path.exists(GENRE_CSV_PATH):
        print(f"Creating a dummy '{GENRE_CSV_PATH}' for demonstration.")
        csv_dir = os.path.dirname(GENRE_CSV_PATH)
        if csv_dir:
            # The models directory may not exist yet; open() would fail without it.
            os.makedirs(csv_dir, exist_ok=True)
        with open(GENRE_CSV_PATH, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerows([
                ["genre_id", "genre_name"],
                [0, "60s"],
                [176, "Rock"],  # From user's example
                [48, "Electronic"],  # From user's example
                [10, "Ambient"],
                [20, "Classical"],
            ])
        # The genre table is deliberately NOT reloaded here: the dropdown in
        # `iface` was already built from the genres loaded at import time, and
        # replacing genre_to_id_map now would make the lookup disagree with
        # the choices shown to the user. The new file is picked up on the
        # next run.

    iface.launch(share=True)