lovelyai999 commited on
Commit
2ea45f5
·
verified ·
1 Parent(s): e81756a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1395 -0
app.py CHANGED
@@ -0,0 +1,1395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time, random
4
+ from random import choice
5
+ from typing import List, Dict
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.optim as optim
9
+ import music21
10
+ import numpy as np
11
+ from sklearn.preprocessing import MultiLabelBinarizer
12
+ from tqdm import tqdm
13
+ import wave
14
+ import struct
15
+ import ffmpeg
16
+ import tempfile
17
+ from pydub import AudioSegment
18
+ from moviepy.editor import VideoFileClip, AudioFileClip
19
+ from torch.utils.data import Dataset, DataLoader
20
+ from torch.utils.data import DataLoader, SubsetRandomSampler
21
+ from torch.utils.tensorboard import SummaryWriter
22
+ from torch.nn.utils import clip_grad_norm_
23
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
24
+ from torch.utils.tensorboard import SummaryWriter
25
+
26
# Base paths for data and HF cache; overridden below when running inside Colab.
Gbase = "./"
cache_dir = "./hf/"

try:
    import google.colab
    from google.colab import drive

    IN_COLAB = True
    drive.mount('/gdrive', force_remount=True)
    Gbase = "/gdrive/MyDrive/generate/"
    cache_dir = "/gdrive/MyDrive/hf/"
    sys.path.append(Gbase)
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still
    # propagate; any import/mount failure falls back to local paths.
    IN_COLAB = False
    Gbase = "./"
    cache_dir = "./hf/"

# Checkpoint paths for the generator, discriminator, their optimizers, and
# the pickled tag evaluator.
ModelPath = os.path.join(Gbase, 'music_generation_model.pth')
OptimizerPath = os.path.join(Gbase, 'optimizer_state.pth')
DiscriminatorModelPath = os.path.join(Gbase, 'discriminator_model.pth')
DiscriminatorOptimizerPath = os.path.join(Gbase, 'discriminator_optimizer_state.pth')
EvaluatorPath = os.path.join(Gbase, 'music_tag_evaluator.pkl')

# Closed vocabulary of music tags, grouped by category.
MUSIC_TAGS = {
    'emotions': ['Happy', 'Sad', 'Angry', 'Peaceful', 'Neutral'],
    'genres': ['Classical', 'Jazz', 'Rock', 'Electronic'],
    'tempo': ['Slow', 'Medium', 'Fast'],
    'instrumentation': ['Piano', 'Guitar', 'Synthesizer'],
    'harmony': ['Consonant', 'Dissonant', 'Complex', 'Simple'],
    'dynamics': ['Dynamic', 'Static'],
    'rhythm': ['Simple', 'Complex']
}
61
+
62
def randomMusicTags():
    """Return a dict mapping every MUSIC_TAGS category to one randomly chosen tag."""
    return {category: choice(options) for category, options in MUSIC_TAGS.items()}

print("随机生成的音乐标签:", randomMusicTags())
66
+
67
def get_scale_notes(key_str: str, octave_range=(2, 6)) -> List[int]:
    """
    Return the MIDI pitch numbers of the scale for `key_str`, covering every
    octave in `octave_range` (inclusive at both ends).
    """
    key = music21.key.Key(key_str)
    scale = key.getScale()
    midi_pitches: List[int] = []
    for octave in range(octave_range[0], octave_range[1] + 1):
        midi_pitches.extend(p.midi for p in scale.getPitches(f"{key_str}{octave}"))
    return midi_pitches
78
+
79
def composer_from_features(features: np.ndarray, key_str: str) -> music21.stream.Stream:
    """
    Convert an [N, 3] feature array of (pitch, duration, volume) rows into a
    music21 Stream in key `key_str`, snapping pitches to the key's scale and
    durations to a fixed set of quarter-length values.

    A pitch of 0 encodes a rest (see midi_to_features).
    """
    s = music21.stream.Stream()

    # Tempo (BPM), fixed at 120.
    tempo = music21.tempo.MetronomeMark(number=120)
    s.append(tempo)

    # Key signature.
    tonality = music21.key.Key(key_str)
    s.append(tonality)

    # Scale membership for pitch snapping.
    scale_notes = get_scale_notes(key_str)

    # Durations a note/rest may take (in quarter lengths).
    acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]

    for feature in features:
        pitch = int(round(feature[0]))
        duration = feature[1]
        volume = feature[2]

        # Quantize the duration to the nearest acceptable value.
        duration = min(acceptable_durations, key=lambda x: abs(x - duration))

        if pitch == 0:
            # Rest. BUGFIX: this check must happen BEFORE clamping the pitch
            # to [21, 108] — the old order made `pitch == 0` unreachable, so
            # rests were silently turned into notes.
            r = music21.note.Rest(quarterLength=duration)
            s.append(r)
            continue

        # Clamp to the piano range A0 (21) .. C8 (108).
        pitch = max(21, min(108, pitch))

        # Snap to the nearest scale tone.
        if pitch not in scale_notes:
            pitch = min(scale_notes, key=lambda x: abs(x - pitch))

        # Clamp velocity to the MIDI range 0..127.
        volume = max(0, min(127, volume))

        n = music21.note.Note(midi=pitch, quarterLength=duration)
        n.volume.velocity = volume
        s.append(n)
    return s
126
+
127
+ import pickle
128
+
129
class MusicTagEvaluator:
    """Rule-based evaluator that assigns MUSIC_TAGS labels to generated music."""

    def __init__(self):
        # Reference the global tag vocabulary.
        self.MUSIC_TAGS = MUSIC_TAGS
        # Flatten all categories into one list and drop duplicates
        # (e.g. 'Simple'/'Complex' appear under both harmony and rhythm).
        all_tags = []
        for category in self.MUSIC_TAGS:
            all_tags.extend(self.MUSIC_TAGS[category])
        self.all_tags = list(set(all_tags))  # remove duplicate tags
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([self.all_tags])

    def save(self, path):
        """Pickle this evaluator to `path`."""
        with open(path, 'wb') as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"评估器已保存至 '{path}'。")

    @staticmethod
    def load(path):
        """Unpickle an evaluator from `path`, or build a fresh one if the file is missing."""
        if os.path.exists(path):
            with open(path, 'rb') as f:
                evaluator = pickle.load(f)
            print(f"评估器已从 '{path}' 加载。")
            return evaluator
        else:
            print(f"评估器文件 '{path}' 不存在,将创建新的评估器。")
            return MusicTagEvaluator()

    def evaluate_tags_from_features(self, features: np.ndarray) -> List[str]:
        """
        Render `features` to a music21 stream (in a randomly chosen key) and
        return the list of valid tags assigned by evaluate_tags.
        """
        # Pick a random key for rendering the features.
        key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
        s = composer_from_features(features, key_str)
        tag_scores = self.evaluate_tags(s)
        tags = []
        # Keep only scores that are valid tags of their own category.
        for category in self.MUSIC_TAGS:
            tag = tag_scores.get(category)
            if tag in self.MUSIC_TAGS[category]:
                tags.append(tag)
        return tags

    def evaluate_tags(self, generated_music):
        """
        Score a music21 stream on the emotion/harmony/rhythm/dynamics/tempo
        categories (genres and instrumentation are not scored here).
        """
        tag_scores = {}

        # Pitch range in semitones.  NOTE(review): computed but not used below.
        pitch_values = [note.pitch.midi for note in generated_music.recurse().notes if isinstance(note, music21.note.Note)]
        pitch_range = max(pitch_values) - min(pitch_values) if pitch_values else 0

        # Evaluate each aspect independently.
        harmony_tag = self._evaluate_harmony(generated_music)
        rhythm_tag = self._evaluate_rhythm(generated_music)
        dynamics_tag = self._evaluate_dynamics(generated_music)
        tempo_tag = self._evaluate_tempo(generated_music)
        emotion_tag = self._evaluate_emotion(harmony_tag, rhythm_tag, dynamics_tag, tempo_tag)

        # Assemble the per-category tag set.
        tag_scores['emotions'] = emotion_tag
        tag_scores['harmony'] = harmony_tag
        tag_scores['rhythm'] = rhythm_tag
        tag_scores['dynamics'] = dynamics_tag
        tag_scores['tempo'] = tempo_tag

        return tag_scores

    def _evaluate_harmony(self, stream):
        # Chordify the stream so vertical sonorities can be inspected.
        chords = stream.chordify()
        chord_types = []
        for element in chords.recurse():
            if isinstance(element, music21.chord.Chord):
                chord_types.append(element.commonName)

        # Classify harmonic complexity by chord-quality names.
        if any('diminished' in str(ct) or 'augmented' in str(ct) for ct in chord_types):
            harmony_tag = 'Complex'
        elif any('major' in str(ct) or 'minor' in str(ct) for ct in chord_types):
            harmony_tag = 'Consonant'
        else:
            harmony_tag = 'Simple'

        return harmony_tag

    def _evaluate_rhythm(self, stream):
        durations = [note.quarterLength for note in stream.flat.notes]
        # Rhythmic complexity is approximated by the number of distinct durations.
        unique_durations = len(set(durations))

        if unique_durations > 5:
            rhythm_tag = 'Complex'
        else:
            rhythm_tag = 'Simple'

        return rhythm_tag

    def _evaluate_dynamics(self, stream):
        volumes = [note.volume.velocity for note in stream.flat.notes if note.volume.velocity is not None]

        if not volumes:
            dynamics_tag = 'Static'
        else:
            # A wide velocity spread counts as "Dynamic".
            dynamics_range = max(volumes) - min(volumes)
            if dynamics_range > 40:
                dynamics_tag = 'Dynamic'
            else:
                dynamics_tag = 'Static'

        return dynamics_tag

    def _evaluate_tempo(self, stream):
        # Use the first MetronomeMark found, defaulting to 120 BPM.
        tempos = [metronome.number for metronome in stream.recurse() if isinstance(metronome, music21.tempo.MetronomeMark)]
        bpm = tempos[0] if tempos else 120  # default BPM is 120

        if bpm < 60:
            return 'Slow'
        elif 60 <= bpm < 120:
            return 'Medium'
        else:
            return 'Fast'

    def _evaluate_emotion(self, harmony_tag, rhythm_tag, dynamics_tag, tempo_tag):
        # Heuristic mapping from the four aspect tags onto an emotion label.
        if harmony_tag == 'Complex' and rhythm_tag == 'Complex':
            emotion = 'Angry'
        elif harmony_tag == 'Consonant' and dynamics_tag == 'Dynamic' and tempo_tag == 'Fast':
            emotion = 'Happy'
        elif harmony_tag == 'Simple' and dynamics_tag == 'Static' and tempo_tag == 'Slow':
            emotion = 'Peaceful'
        elif harmony_tag == 'Consonant' and dynamics_tag == 'Static' and tempo_tag == 'Medium':
            emotion = 'Neutral'
        else:
            emotion = 'Sad'

        return emotion
268
+
269
class PositionalEncoding(nn.Module):
    """Adds the fixed sinusoidal positional encoding (Vaswani et al.) to a batch-first tensor."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        freqs = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))  # [d_model/2]

        table = torch.zeros(max_len, d_model)  # [max_len, d_model]
        table[:, 0::2] = torch.sin(positions * freqs)  # even dimensions
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dimensions

        # Buffer (not a parameter): moves with .to(device) but is never trained.
        self.register_buffer('pe', table.unsqueeze(0))  # [1, max_len, d_model]

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model] -> same shape, with position encodings added.
        """
        return x + self.pe[:, :x.size(1), :]
289
+
290
class MusicGenerationModel(nn.Module):
    """Transformer encoder that jointly predicts per-step music features and tag probabilities."""

    def __init__(self, input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags, max_seq_length=500):
        super(MusicGenerationModel, self).__init__()
        self.d_model = d_model
        self.input_linear = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_len=max_seq_length)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=0.1)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_encoder_layers)
        self.fc_music = nn.Linear(d_model, output_dim)
        self.fc_tags = nn.Linear(d_model, num_tags)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """
        src: [batch_size, seq_len, input_dim]
        Returns (music_output [batch, seq, output_dim], tag_probabilities [batch, seq, num_tags]).
        """
        embedded = self.input_linear(src) * np.sqrt(self.d_model)  # scale as in the Transformer paper
        embedded = self.positional_encoding(embedded)              # [batch, seq, d_model]
        # nn.TransformerEncoder expects [seq, batch, d_model].
        encoded = self.transformer_encoder(embedded.transpose(0, 1), mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        encoded = self.dropout(encoded.transpose(0, 1))            # back to [batch, seq, d_model]
        music_output = self.fc_music(encoded)                      # [batch, seq, output_dim]
        tag_probabilities = self.sigmoid(self.fc_tags(encoded))    # [batch, seq, num_tags]
        return music_output, tag_probabilities
316
+
317
class Discriminator(nn.Module):
    """Transformer-based real/fake critic over feature sequences (GAN discriminator)."""

    def __init__(self, input_dim, d_model, nhead, num_layers, dim_feedforward):
        super(Discriminator, self).__init__()
        self.input_linear = nn.Linear(input_dim, d_model)
        self.positional_encoding = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, 1)
        self.sigmoid = nn.Sigmoid()
        self.dropout = nn.Dropout(0.1)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """src: [batch, seq, input_dim] -> probability tensor of shape [batch, 1]."""
        scaled = self.input_linear(src) * np.sqrt(self.input_linear.out_features)
        scaled = self.positional_encoding(scaled)
        # nn.TransformerEncoder expects [seq, batch, d_model].
        encoded = self.transformer_encoder(scaled.transpose(0, 1), mask=src_mask, src_key_padding_mask=src_key_padding_mask)
        encoded = self.dropout(encoded.transpose(0, 1))  # [batch, seq, d_model]
        # Judge from the final timestep only (mean-pooling would be an alternative).
        return self.sigmoid(self.fc(encoded[:, -1, :]))
339
+
340
class MidiDataset(Dataset):
    """Dataset of fixed-length MIDI feature segments with evaluator-assigned tag labels."""

    def __init__(self, midi_files: List[str], max_length: int, dataset_path: str, evaluator: MusicTagEvaluator):
        self.max_length = max_length
        self.dataset_path = dataset_path
        self.evaluator = evaluator
        # Check whether a preprocessed dataset file already exists.
        if os.path.exists(self.dataset_path):
            # Load the cached, preprocessed dataset.
            print(f"从 '{self.dataset_path}' 加载数据集")
            try:
                saved_data = torch.load(self.dataset_path)
                self.features = saved_data['features']
                self.labels = saved_data['labels']
                print(f"成功加载数据集,共有 {len(self.features)} 个样本。")
            except Exception as e:
                # Fall back to re-processing if the cache is unreadable.
                print(f"加载数据集时出错: {e}")
                self._process_midi_files(midi_files)
        else:
            # Process the MIDI files and persist the dataset.
            self._process_midi_files(midi_files)

    def __len__(self):
        return len(self.features)

    def getAug(self, idx):
        """Like __getitem__, but with random augmentation applied to the sample."""
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]  # [num_tags]
        # Apply data augmentation.
        feature_aug, label_aug = self._augment_data(feature, label)
        # Return as tensors.
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)

    def __getitem__(self, idx):
        feature = self.features[idx]  # [seq_len, input_dim]
        label = self.labels[idx]  # [num_tags]
        # No augmentation here (use getAug for the augmented variant).
        feature_aug, label_aug = feature, label
        # Return as tensors.
        return torch.tensor(feature_aug, dtype=torch.float32), torch.tensor(label_aug, dtype=torch.float32)

    def _process_midi_files(self, midi_files):
        """Parse MIDI files, split into max_length segments, tag each segment, and cache to disk."""
        print("处理 MIDI 文件以创建数据集...")
        features_list = []
        labels_list = []
        for midi_file in midi_files:
            try:
                stream = music21.converter.parse(midi_file)
                # Convert the parsed stream into a feature sequence.
                features = self.midi_to_features(stream)
                if len(features) < self.max_length:
                    # Skip pieces shorter than one segment.
                    continue
                else:
                    # Split the features into segments of length max_length.
                    num_segments = len(features) // self.max_length
                    for i in range(num_segments):
                        segment = features[i*self.max_length : (i+1)*self.max_length]
                        if len(segment) < self.max_length:
                            continue  # skip incomplete segments
                        # Use the evaluator to assign tags to this segment.
                        tags = self.evaluator.evaluate_tags_from_features(segment)
                        # Binarize the tags into a multi-hot vector.
                        tag_binarized = self.evaluator.mlb.transform([tags])[0]
                        features_list.append(segment)
                        labels_list.append(tag_binarized)
            except Exception as e:
                print(f"处理 {midi_file} 时出错: {e}")
        self.features = features_list
        self.labels = labels_list
        # Persist the dataset for future runs.
        try:
            torch.save({'features': self.features, 'labels': self.labels}, self.dataset_path)
            print(f"数据集已保存至 '{self.dataset_path}',共有 {len(self.features)} 个样本。")
        except Exception as e:
            print(f"保存数据集时出错: {e}")

    def midi_to_features(self, stream) -> np.ndarray:
        """
        Convert a music21 stream into an [N, 3] float32 array of
        (pitch, duration, volume) rows; rests are encoded with pitch 0.
        """
        features = []
        for note in stream.flat.notesAndRests:
            if isinstance(note, music21.note.Note):
                pitch = note.pitch.midi
                duration = note.quarterLength
                volume = note.volume.velocity if note.volume.velocity else 64  # default velocity
            elif isinstance(note, music21.note.Rest):
                pitch = 0  # rests use pitch 0
                duration = note.quarterLength
                volume = 0
            else:
                continue
            features.append([pitch, duration, volume])
        return np.array(features, dtype=np.float32)

    def _augment_data(self, feature, label):
        """Randomly vary volume (dynamics) and duration (tempo) and fix up the tempo labels."""
        feature_aug = np.copy(feature)
        label_aug = np.copy(label)
        # Randomly scale the volume (dynamics).
        volume_change = np.random.uniform(0.8, 1.2)
        feature_aug[:, 2] *= volume_change
        feature_aug[:, 2] = np.clip(feature_aug[:, 2], 0, 127)
        # Randomly scale the durations (tempo change).
        duration_change = np.random.uniform(0.9, 1.1)
        feature_aug[:, 1] *= duration_change
        # Adjust the 'tempo' label to reflect the change.
        # NOTE(review): longer durations arguably mean a *slower* tempo —
        # confirm whether this mapping is intended.
        if duration_change > 1.05:
            # Faster tempo.
            tempo_tags = ['Fast']
        elif duration_change < 0.95:
            # Slower tempo.
            tempo_tags = ['Slow']
        else:
            tempo_tags = ['Medium']
        # Clear all tempo tags, then set the chosen one.
        for tempo in ['Slow', 'Medium', 'Fast']:
            label_aug[self.evaluator.all_tags.index(tempo)] = 0
        tempo_index = self.evaluator.all_tags.index(tempo_tags[0])
        label_aug[tempo_index] = 1
        return feature_aug, label_aug
462
+
463
class MidiDatasetAug(MidiDataset):
    """MidiDataset variant that applies random augmentation on every access.

    The previous implementation duplicated the entire MidiDataset body
    verbatim (loading, MIDI processing, feature extraction, augmentation)
    and contained an unreachable statement after the ``return`` in
    ``__getitem__``.  All of that logic already lives on MidiDataset
    (including ``getAug`` and ``_augment_data``), so this class now simply
    inherits it and overrides ``__getitem__`` to return augmented samples.
    The constructor signature and the returned tensors are unchanged.
    """

    def __getitem__(self, idx):
        # Delegate to the augmenting accessor defined on MidiDataset.
        return self.getAug(idx)
580
+
581
+
582
+
583
class RandomDataset(Dataset):
    """
    Synthetic dataset of random (pitch, duration, volume) sequences.

    Args:
        size: number of samples the dataset reports.
        max_length: sequence length of every sample.
    """

    def __init__(self, size: int, max_length: int):
        self.size = size
        self.max_length = max_length

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        n = self.max_length
        # Pitches across the piano range A0 (21) .. C8 (108).
        pitches = np.random.randint(21, 109, size=(n, 1)).astype(np.float32)
        # Durations drawn from the fixed set of acceptable quarter lengths.
        acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
        durations = np.random.choice(acceptable_durations, size=(n, 1)).astype(np.float32)
        # Velocities in [40, 70).
        volumes = np.random.randint(40, 70, size=(n, 1)).astype(np.float32)
        sample = np.concatenate([pitches, durations, volumes], axis=-1)  # [max_length, 3]
        return torch.tensor(sample, dtype=torch.float32)
608
+
609
+ class MusicGenerator:
610
+ def __init__(self, model: nn.Module, evaluator, device: torch.device, model_path: str, optimizer=None, optimizer_path: str=None, writer: SummaryWriter=None):
611
+ self.model = model.to(device)
612
+ self.evaluator = evaluator
613
+ self.device = device
614
+ self.model_path = model_path
615
+ self.optimizer = optimizer
616
+ self.optimizer_path = optimizer_path
617
+ self.writer = writer
618
+ self._load_model()
619
+ # 定义归一化和反归一化参数
620
+ self.min_pitch = 21
621
+ self.max_pitch = 108
622
+ self.min_duration = 0.15
623
+ self.max_duration = 1.5
624
+ self.min_volume = 40
625
+ self.max_volume = 85
626
+
627
+ def _load_model(self):
628
+ """自动载入已存在的模型权重,如果存在的话。"""
629
+ if os.path.exists(self.model_path):
630
+ try:
631
+ self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
632
+ self.model.to(self.device)
633
+ self.model.eval()
634
+ print(f"已成功载入模型权重从 '{self.model_path}'。")
635
+ except Exception as e:
636
+ print(f"载入模型权重时出错: {e},将初始化新模型。")
637
+ else:
638
+ print("未找到已保存的模型,将初始化新模型。")
639
+
640
+ # 加载优化器状态
641
+ if self.optimizer and self.optimizer_path and os.path.exists(self.optimizer_path):
642
+ try:
643
+ self.optimizer.load_state_dict(torch.load(self.optimizer_path, map_location=self.device))
644
+ print(f"已成功载入优化器状态从 '{self.optimizer_path}'。")
645
+ except Exception as e:
646
+ print(f"载入优化器状态时出错: {e},将初始化新优化器。")
647
+ else:
648
+ if self.optimizer and self.optimizer_path:
649
+ print("未找到已保存的优化器状态,将初始化新优化器。")
650
+
651
+ def save_model(self, epoch: int, loss: float):
652
+ """保存当前模型的权重和优化器状态。"""
653
+ try:
654
+ torch.save(self.model.state_dict(), self.model_path, _use_new_zipfile_serialization=False)
655
+ if self.optimizer and self.optimizer_path:
656
+ torch.save(self.optimizer.state_dict(), self.optimizer_path, _use_new_zipfile_serialization=False)
657
+ print(f"模型和优化器已保存至 '{self.model_path}' 和 '{self.optimizer_path}'。")
658
+ if self.writer:
659
+ self.writer.add_scalar('Loss/Save', loss, epoch)
660
+ except Exception as e:
661
+ print(f"保存模型或优化器时出错: {e}")
662
+
663
+ def train_epoch(self, dataloader: DataLoader, optimizer, criterion_music, criterion_tags, epoch: int):
664
+ """
665
+ 训练一个 epoch。
666
+ """
667
+ self.model.train()
668
+ total_loss = 0.0
669
+ for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
670
+ batch_features = batch_features.to(self.device) # [batch_size, seq_len, input_dim]
671
+ batch_labels = batch_labels.to(self.device) # [batch_size, num_tags]
672
+ inputs = batch_features[:, :-1, :] # [batch_size, seq_len-1, input_dim]
673
+ targets = batch_features[:, -1, :] # [batch_size, input_dim]
674
+
675
+ optimizer.zero_grad()
676
+ music_output, tag_probabilities = self.model(inputs) # 音乐输出: [batch, seq_len-1, output_dim]
677
+
678
+ # 只对最后一个时间步的输出进行损失计算
679
+ loss_music = criterion_music(music_output[:, -1, :], targets)
680
+
681
+ # 使用数据集中的标签
682
+ loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)
683
+
684
+ # 总损失
685
+ loss = loss_music + loss_tags
686
+ loss.backward()
687
+
688
+ # 梯度裁剪
689
+ clip_grad_norm_(self.model.parameters(), max_norm=1.0)
690
+
691
+ optimizer.step()
692
+
693
+ total_loss += loss.item()
694
+ if self.writer:
695
+ self.writer.add_scalar('Loss/Train', loss.item(), epoch * len(dataloader) + batch_idx)
696
+
697
+ avg_loss = total_loss / len(dataloader)
698
+ print(f"Epoch {epoch} 平均损失: {avg_loss:.4f}")
699
+ return avg_loss
700
+
701
+ def train_epoch_gan(self, dataloader, optimizer_generator, optimizer_discriminator, criterion_music, criterion_tags, criterion_discriminator, discriminator, epoch):
702
+ """
703
+ 使用对抗训练的方法训练一个 epoch。
704
+ """
705
+ self.model.train()
706
+ discriminator.train()
707
+ total_loss = 0.0
708
+
709
+ for batch_idx, (batch_features, batch_labels) in enumerate(tqdm(dataloader, desc=f"Epoch {epoch}", leave=False)):
710
+ batch_features = batch_features.to(self.device) # [batch_size, seq_len, input_dim]
711
+ batch_labels = batch_labels.to(self.device) # [batch_size, num_tags]
712
+ batch_size = batch_features.size(0)
713
+ seq_len = batch_features.size(1)
714
+ # ---------------------
715
+ # 训练判别器
716
+ # ---------------------
717
+
718
+ # 使用真实数据
719
+ real_data = batch_features # [batch_size, seq_len, input_dim]
720
+ real_labels = torch.ones(batch_size, 1).to(self.device)
721
+
722
+ # 使用生成器生成假数据
723
+ noise = torch.rand(batch_size, seq_len, 3).to(self.device) # 随机噪声在 [0,1],与归一化后的特征一致
724
+ generated_features = torch.zeros_like(batch_features).to(self.device)
725
+ for i in range(seq_len):
726
+ input_noise = noise[:, :i+1, :]
727
+ fake_data, _ = self.model(input_noise)
728
+ generated_features[:, i, :] = fake_data[:, -1, :]
729
+
730
+ fake_data = generated_features.detach() # [batch_size, seq_len, input_dim]
731
+ fake_labels = torch.zeros(batch_size, 1).to(self.device)
732
+
733
+ # 计算判别器在真实数据上的损失
734
+ optimizer_discriminator.zero_grad()
735
+ output_real = discriminator(real_data)
736
+ loss_real = criterion_discriminator(output_real, real_labels)
737
+
738
+ # 计算判别器在假数据上的损失
739
+ output_fake = discriminator(fake_data)
740
+ loss_fake = criterion_discriminator(output_fake, fake_labels)
741
+
742
+ # 总损失并反向传播
743
+ loss_discriminator = (loss_real + loss_fake) / 2
744
+ loss_discriminator.backward()
745
+ optimizer_discriminator.step()
746
+
747
+ # ---------------------
748
+ # 训练生成器
749
+ # ---------------------
750
+
751
+ optimizer_generator.zero_grad()
752
+ # 生成假数据并计算生成器的损失,目标是让判别器相信这些数据是真实的
753
+ output_fake_for_generator = discriminator(fake_data)
754
+ loss_generator_adv = criterion_discriminator(output_fake_for_generator, real_labels) # 生成器的对抗损失
755
+
756
+ # 计算生成器的音乐特征和标签损失
757
+ music_output, tag_probabilities = self.model(noise)
758
+ targets = batch_features[:, -1, :] # 真实的最后一个特征
759
+ loss_music = criterion_music(music_output[:, -1, :], targets)
760
+
761
+ # 使用数据集中的标签
762
+ loss_tags = criterion_tags(tag_probabilities[:, -1, :], batch_labels)
763
+
764
+ # 总损失
765
+ loss_generator = loss_generator_adv + loss_music + loss_tags
766
+ loss_generator.backward()
767
+
768
+ # 梯度裁剪
769
+ clip_grad_norm_(self.model.parameters(), max_norm=1.0)
770
+
771
+ optimizer_generator.step()
772
+
773
+ total_loss += loss_generator.item()
774
+ if self.writer:
775
+ #self.writer.add_scalar('Loss/Generator', loss_generator.item(), epoch * len(dataloader) + batch_idx)
776
+ #self.writer.add_scalar('Loss/Discriminator', loss_discriminator.item(), epoch * len(dataloader) + batch_idx)
777
+ pass
778
+
779
+ avg_loss = total_loss / len(dataloader)
780
+ print(f"Epoch {epoch} 平均生成器损失: {avg_loss:.4f}")
781
+ return avg_loss
782
+
783
+
784
+ def generate_music(self, tag_conditions: dict={
785
+ 'emotions': 'Neutral',
786
+ 'genres': 'Classical',
787
+ 'tempo': 'Medium',
788
+ 'instrumentation': 'Piano',
789
+ 'harmony': 'Simple',
790
+ 'dynamics': 'Dynamic',
791
+ 'rhythm': 'Simple' # 或 'Complex'
792
+ }, max_length=100, temperature=1.0) -> music21.stream.Stream:
793
+ """
794
+ 根据标签生成音乐。
795
+ """
796
+ self.model.eval()
797
+ acceptable_durations = [0.25, 0.333, 0.5, 0.666, 0.75, 1.0, 1.5, 2.0, 3.0, 4.0]
798
+ generated_features = []
799
+
800
+ with torch.no_grad():
801
+ # 随机选择一个调性
802
+ key_str = choice(['C', 'G', 'D', 'A', 'E', 'B', 'F#', 'C#', 'F', 'Bb', 'Eb', 'Ab', 'Db', 'Gb', 'Cb'])
803
+ scale_notes = get_scale_notes(key_str)
804
+
805
+ # 初始输入(随机特征)
806
+ input_feature = torch.zeros(1, 1, 3).to(self.device) # [batch_size=1, seq_len=1, input_dim=3]
807
+
808
+ for _ in range(max_length):
809
+ music_output, tag_probabilities = self.model(input_feature) # [1, seq_len, 3] and [1, seq_len, num_tags]
810
+ music_output_np = music_output.cpu().numpy()[0, -1]
811
+
812
+ # 应用温度控制
813
+ music_output_np = music_output_np / temperature
814
+
815
+ # 使用概率分布进行采样
816
+ pitch = int(round(music_output_np[0]))
817
+ duration = music_output_np[1]
818
+ volume = int(round(music_output_np[2]))
819
+
820
+ # 增加随机变动
821
+ pitch += int(np.random.uniform(-2, 2))
822
+ pitch = max(21, min(108, pitch)) # 限制在钢琴键范围内
823
+ # 将音高映射到最近的音阶音符
824
+ if pitch not in scale_notes:
825
+ pitch = min(scale_notes, key=lambda x: abs(x - pitch))
826
+ duration += np.random.uniform(-0.1, 0.1)
827
+ try:
828
+ duration = min(acceptable_durations, key=lambda x: abs(x - duration))
829
+ except ValueError:
830
+ duration = 1.0 # 默认时值
831
+ volume += int(np.random.uniform(-10, 10))
832
+ volume = max(70, min(100, volume)) # 限制音量范围
833
+
834
+ # 保存特征
835
+ generated_features.append([pitch, duration, volume])
836
+
837
+ # 更新输入
838
+ next_input = torch.tensor([[pitch, duration, volume]], dtype=torch.float32).to(self.device).unsqueeze(0) # [1, 1, 3]
839
+ input_feature = torch.cat((input_feature, next_input), dim=1) # 增加序列长度
840
+
841
+ # 转换为 numpy 数组
842
+ generated_features_array = np.array(generated_features, dtype=np.float32)
843
+ generated_stream = composer_from_features(generated_features_array, key_str)
844
+
845
+ # 评估标签
846
+ tag_scores = self.evaluator.evaluate_tags(generated_stream)
847
+ print("生成的音乐标签:", tag_scores)
848
+
849
+ # 根据情感进行判断并保存
850
+ high_score_emotions = ['Happy', 'Peaceful']
851
+ if tag_scores.get('emotions') in high_score_emotions:
852
+ # 将生成的 MIDI 转换为 WAV
853
+ midi_filename = f'high_score_{int(time.time())}.mid'
854
+ generated_stream.write('midi', fp=os.path.join(Gbase, midi_filename))
855
+ wav_file = self.custom_midi_to_wav(generated_stream, os.path.join(Gbase, f'high_score_{int(time.time())}.wav'))
856
+ print(f"高评分音乐已保存为 WAV 文件: '{wav_file}'")
857
+
858
+ return generated_stream
859
+
860
+ def addMusicToVideo(self, videoPath, tagConditions={
861
+ 'emotions': 'Neutral',
862
+ 'genres': 'Classical',
863
+ 'tempo': 'Medium',
864
+ 'instrumentation': 'Piano',
865
+ 'harmony': 'Simple',
866
+ 'dynamics': 'Dynamic',
867
+ 'rhythm': 'Simple' # 或 'Complex'
868
+ }, outputPath=None):
869
+ """
870
+ 根据指定的标签条件生成音乐,并将其附加到视频中,确保音乐的长度与视频一致。
871
+
872
+ 参数:
873
+ videoPath (str): 输入视频的路径。
874
+ tagConditions (dict): 用于生成音乐的标签条件。
875
+ outputPath (str, optional): 输出视频的路径。如果未指定,将在原路径基础上添加 '_with_music'。
876
+
877
+ 返回:
878
+ str: 输出的视频路径。
879
+ """
880
+ # 1. 获取视频时长
881
+ try:
882
+ video = VideoFileClip(videoPath)
883
+ duration = video.duration
884
+ print(f"视频时长: {duration} 秒。")
885
+ except Exception as e:
886
+ print(f"无法载入视频: {e}")
887
+ return None
888
+ if not outputPath:
889
+ base, ext = os.path.splitext(videoPath)
890
+ outputPath = f"{base}_with_music{ext}"
891
+ if os.path.exists (outputPath):return outputPath
892
+ # 2. 初始化音频拼接
893
+ combined_audio = AudioSegment.silent(duration=0) # 初始化为空音频
894
+ total_generated_duration = 0 # 总生成时长(毫秒)
895
+ chunk_duration_seconds = 10 # 每次生成音讯的预估时长(秒),根据需要调整
896
+ crossfade_duration = 500 # 淡入淡出持续时间(毫秒)
897
+
898
+ # 3. 逐段生成音频
899
+ print("逐段生成音乐中...")
900
+ while total_generated_duration < duration * 1000: # pydub 使用毫秒
901
+ # 根据剩余时长生成音乐,确保不生成过多
902
+ remaining_duration_ms = duration * 1000 - total_generated_duration
903
+ remaining_duration_seconds = remaining_duration_ms / 1000.0
904
+ current_chunk_length = min(chunk_duration_seconds, remaining_duration_seconds)
905
+
906
+ # 计算所需的音符数量,假设每个音符平均约0.5秒
907
+ estimated_max_length = int(current_chunk_length / 0.5) * 2 # 调整因子根据实际情况
908
+
909
+ # 生成音乐流
910
+ generated_stream = self.generate_music(max_length=100)
911
+
912
+ # 转换为 WAV 文件
913
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as wav_temp:
914
+ wav_filename = wav_temp.name
915
+ wav_path = self.custom_midi_to_wav(generated_stream, wav_filename)
916
+ print(f"生成的 WAV 已保存为 '{wav_path}'。")
917
+
918
+ # 加载生成的音频
919
+ try:
920
+ generated_audio = AudioSegment.from_wav(wav_path)
921
+ except Exception as e:
922
+ print(f"加载生成的音频时出错: {e}")
923
+ os.remove(wav_path)
924
+ #break
925
+
926
+ # 拼接音频,应用淡入淡出效果
927
+ if len(combined_audio) == 0:
928
+ # 第一段音频,仅应用淡入
929
+ generated_audio = generated_audio.fade_in(crossfade_duration)
930
+ combined_audio += generated_audio
931
+ else:
932
+ # 之后的音频段,应用淡出和淡入,并设置 crossfade
933
+ generated_audio = generated_audio.fade_in(crossfade_duration)
934
+ combined_audio = combined_audio.append(generated_audio, crossfade=crossfade_duration)
935
+
936
+ total_generated_duration = len(combined_audio)
937
+
938
+ # 删除临时 WAV 文件
939
+ try:
940
+ os.remove(wav_path)
941
+ print(f"已删除临时 WAV 文件 '{wav_path}'。")
942
+ except Exception as e:
943
+ print(f"删除临时 WAV 文件时出错: {e}")
944
+
945
+ # 4. 剪切音频以匹配视频时长
946
+ final_audio = combined_audio[:int(duration * 1000)] # pydub 使用毫秒为单位
947
+ final_wav_path = tempfile.mktemp(suffix='.wav')
948
+ final_audio.export(final_wav_path, format="wav")
949
+ print(f"最终剪切后的 WAV 已保存为 '{final_wav_path}'。")
950
+
951
+ # 5. 定义输出视频路径
952
+ if not outputPath:
953
+ base, ext = os.path.splitext(videoPath)
954
+ outputPath = f"{base}_with_music{ext}"
955
+
956
+ # 6. 使用 moviepy 将音频与视频结合
957
+ try:
958
+ # 载入视频和音频
959
+ video_clip = VideoFileClip(videoPath)
960
+ audio_clip = AudioFileClip(final_wav_path)
961
+
962
+ # 设置音频,确保音频长度与视频一致
963
+ audio_clip = audio_clip.set_duration(video_clip.duration)
964
+
965
+ # 将音频附加到视频
966
+ video_with_audio = video_clip.set_audio(audio_clip)
967
+
968
+ # 输出最终视频
969
+ video_with_audio.write_videofile(outputPath, codec='libx264', audio_codec='aac', verbose=False, logger=None)
970
+ print(f"输出视频已保存为 '{outputPath}'。")
971
+ except Exception as e:
972
+ print(f"结合视频和音频时出错: {e}")
973
+ return None
974
+ finally:
975
+ # 清理 moviepy 生成的资源
976
+ if 'video_clip' in locals():
977
+ video_clip.close()
978
+ if 'audio_clip' in locals():
979
+ audio_clip.close()
980
+ if 'video_with_audio' in locals():
981
+ video_with_audio.close()
982
+
983
+ # 7. 清理临时文件
984
+ try:
985
+ os.remove(final_wav_path)
986
+ print("最终临时 WAV 文件已删除。")
987
+ except Exception as e:
988
+ print(f"删除最终临时 WAV 文件时出错: {e}")
989
+
990
+ return outputPath
991
+
992
+
993
+
994
+
995
+ def custom_midi_to_wav(self, stream: music21.stream.Stream, wav_filename: str, sample_rate=44100) -> str:
996
+ """
997
+ 自定义的 MIDI 到 WAV 转换函数,使用数学公式生成高质量的音频。
998
+ 改进后:声音更加悦耳,符合音符、音阶、乐器的基本要求。
999
+ """
1000
+ import math
1001
+
1002
+ # 合成参数
1003
+ envelope_attack = 0.01 # 攻击时间
1004
+ envelope_decay = 0.1 # 衰减时间
1005
+ envelope_sustain = 0.8 # 持续水平
1006
+ envelope_release = 0.2 # 释放时间
1007
+
1008
+ # 获取节奏信息
1009
+ metronome_marks = list(stream.metronomeMarkBoundaries())
1010
+ bpm = 120 # 默认 BPM
1011
+ if metronome_marks:
1012
+ # 检查是否存在 MetronomeMark 对象
1013
+ for mark in metronome_marks:
1014
+ if isinstance(mark[2], music21.tempo.MetronomeMark) and mark[2].number:
1015
+ bpm = mark[2].number
1016
+ break
1017
+
1018
+ # 生成时间轴
1019
+ notes = list(stream.flat.getElementsByClass(['Note', 'Chord', 'Rest']))
1020
+ if not notes:
1021
+ print("没有音符可生成音频。")
1022
+ return ""
1023
+
1024
+ # 计算整体时长
1025
+ total_duration = stream.duration.quarterLength * 60 / bpm
1026
+ total_samples = int(total_duration * sample_rate) + 1
1027
+ audio = np.zeros(total_samples)
1028
+
1029
+ current_time = 0
1030
+
1031
+ # 定义乐器的谐波系数,模拟钢琴的谐波
1032
+ harmonic_coeffs = [1.0, 0.5, 0.25, 0.1, 0.05]
1033
+
1034
+ for element in notes:
1035
+ if isinstance(element, music21.note.Rest):
1036
+ # 休止符,更新当前时间
1037
+ duration = element.quarterLength * 60 / bpm # 秒
1038
+ current_time += duration
1039
+ continue
1040
+
1041
+ elif isinstance(element, music21.note.Note):
1042
+ frequencies = [element.pitch.frequency]
1043
+ elif isinstance(element, music21.chord.Chord):
1044
+ frequencies = [p.frequency for p in element.pitches]
1045
+ else:
1046
+ continue
1047
+
1048
+ duration = element.quarterLength * 60 / bpm # 秒
1049
+ # 音量固定为70%
1050
+ volume = 0.6
1051
+
1052
+ # 生成波形时间轴
1053
+ t = np.linspace(0, duration, int(duration * sample_rate), False)
1054
+
1055
+ waveform = np.zeros_like(t)
1056
+ for freq in frequencies:
1057
+ note_waveform = np.zeros_like(t)
1058
+ for idx, coeff in enumerate(harmonic_coeffs):
1059
+ harmonic_freq = freq * (idx + 1)
1060
+ note_waveform += coeff * np.sin(2 * np.pi * harmonic_freq * t)
1061
+ waveform += note_waveform
1062
+
1063
+ # 归一化振幅(避免多个频率叠加导致音量过高)
1064
+ waveform /= len(frequencies) * sum(harmonic_coeffs)
1065
+
1066
+ # 添加 ADSR 包络
1067
+ attack_samples = int(envelope_attack * sample_rate)
1068
+ decay_samples = int(envelope_decay * sample_rate)
1069
+ release_samples = int(envelope_release * sample_rate)
1070
+ sustain_samples = len(waveform) - attack_samples - decay_samples - release_samples
1071
+ if sustain_samples < 0:
1072
+ # 调整 ADSR 以适应短音符
1073
+ total_envelope = envelope_attack + envelope_decay + envelope_release
1074
+ attack_ratio = envelope_attack / total_envelope
1075
+ decay_ratio = envelope_decay / total_envelope
1076
+ release_ratio = envelope_release / total_envelope
1077
+ attack_samples = int(len(waveform) * attack_ratio)
1078
+ decay_samples = int(len(waveform) * decay_ratio)
1079
+ release_samples = len(waveform) - attack_samples - decay_samples
1080
+ sustain_samples = 0
1081
+
1082
+ envelope = np.concatenate([
1083
+ np.linspace(0, 1, attack_samples, False),
1084
+ np.linspace(1, envelope_sustain, decay_samples, False),
1085
+ np.full(sustain_samples, envelope_sustain),
1086
+ np.linspace(envelope_sustain, 0, release_samples, False)
1087
+ ])
1088
+
1089
+ # 调整 envelope 长度
1090
+ envelope = envelope[:len(waveform)]
1091
+
1092
+ waveform *= envelope
1093
+ waveform *= volume
1094
+
1095
+ # 计算样本索引
1096
+ start_sample = int(current_time * sample_rate)
1097
+ end_sample = start_sample + len(waveform)
1098
+ if end_sample > total_samples:
1099
+ end_sample = total_samples
1100
+ waveform = waveform[:end_sample - start_sample]
1101
+
1102
+ # 合成音频
1103
+ audio[start_sample:end_sample] += waveform
1104
+
1105
+ # 更新当前时间
1106
+ current_time += duration
1107
+
1108
+ # 防止削波
1109
+ max_val = np.max(np.abs(audio))
1110
+ if max_val > 1:
1111
+ audio /= max_val
1112
+
1113
+ # 将音频转换为16位整数
1114
+ audio_int16 = np.int16(audio * 32767)
1115
+
1116
+ # 写入 WAV 文件
1117
+ wav_path = os.path.join(os.getcwd(), wav_filename)
1118
+ with wave.open(wav_path, 'w') as wav_file:
1119
+ n_channels = 2
1120
+ sampwidth = 2 # 2 bytes for int16
1121
+ framerate = sample_rate
1122
+ n_frames = len(audio_int16)
1123
+ comptype = "NONE"
1124
+ compname = "not compressed"
1125
+ wav_file.setparams((n_channels, sampwidth, framerate, n_frames, comptype, compname))
1126
+ wav_file.writeframes(audio_int16.tobytes())
1127
+
1128
+ return wav_path
1129
+
1130
+
1131
class AdvancedMusicGenerator(MusicGenerator):
    """Extension point over :class:`MusicGenerator`.

    Currently adds no behavior of its own; it exists so that future
    overrides and extra configuration can be layered on without touching
    the base class.
    """

    def __init__(self, model: nn.Module, evaluator, device: torch.device,
                 model_path: str, optimizer=None, optimizer_path: str = None,
                 writer: SummaryWriter = None):
        # Delegate all setup to the base class; keyword form documents the
        # mapping explicitly.
        super().__init__(model, evaluator, device, model_path,
                         optimizer=optimizer, optimizer_path=optimizer_path,
                         writer=writer)
1137
+
1138
def trainModel():
    """Train the music generation model, with reinforcement and GAN phases.

    Runs ``epochs`` rounds of supervised training; on the final epoch it
    additionally runs a reinforcement pass over the augmented dataset and an
    adversarial pass against the discriminator, then saves all weights.
    Returns early (None) if the MIDI directory is missing or the dataset is
    too small to sample from.
    """
    # TensorBoard logging.
    writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))

    # Tag evaluator and tag vocabulary size.
    evaluator = MusicTagEvaluator.load(EvaluatorPath)
    num_tags = len(evaluator.all_tags)

    # Model hyper-parameters.
    input_dim = 3           # pitch, duration, volume
    d_model = 512           # Transformer model width
    nhead = 8               # attention heads
    num_encoder_layers = 8  # encoder depth
    dim_feedforward = 2048  # feed-forward width
    output_dim = 3          # predicted pitch, duration, volume

    model = MusicGenerationModel(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward, output_dim, num_tags)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    print(f"使用设备: {device}")

    # Collect training MIDI files.
    midi_directory = os.path.join(Gbase, 'generateMIDI')
    midi_files = []
    if os.path.exists(midi_directory):
        midi_files = [os.path.join(midi_directory, f) for f in os.listdir(midi_directory) if f.endswith('.mid') or f.endswith('.midi')]
        print(f"在目录 '{midi_directory}' 中找到 {len(midi_files)} 个 MIDI 文件用于训练。")
    else:
        print(f"MIDI 文件目录 '{midi_directory}' 不存在,请确保该目录存在并包含 MIDI 文件。")
        return

    # Datasets (plain + augmented) share the same cache file.
    max_length = 100
    dataset_path = os.path.join(Gbase, 'mymusic.dataset')
    dataset = MidiDataset(midi_files, max_length, dataset_path, evaluator)
    datasetAug = MidiDatasetAug(midi_files, max_length, dataset_path, evaluator)

    # Sample sizes for the main / reinforcement / small loaders.
    sample_size = 30000 if torch.cuda.is_available() else 15000
    sample_size1 = int(sample_size / 10)
    sample_size2 = int(sample_size / 300)
    total_samples = len(dataset)
    if total_samples < sample_size:
        print(f"数据集中只有 {total_samples} 个样本,无法采样 {sample_size} 个。请检查数据集。")
        return

    # Training schedule.
    epochs = 4
    learning_rate = 0.001
    batch_size = 16 if torch.cuda.is_available() else 4

    # Generator (wraps the model with checkpointing / logging).
    optimizer_generator = optim.AdamW(model.parameters(), lr=learning_rate * 0.1)
    generator = MusicGenerator(model, evaluator, device, model_path=ModelPath, optimizer=optimizer_generator, optimizer_path=OptimizerPath, writer=writer)

    # Discriminator for the adversarial phase.
    discriminator = Discriminator(input_dim, d_model, nhead, num_encoder_layers, dim_feedforward).to(device)
    optimizer_discriminator = optim.AdamW(discriminator.parameters(), lr=learning_rate)
    criterion_discriminator = nn.BCELoss()

    # Resume discriminator weights / optimizer state if present.
    if os.path.exists(DiscriminatorModelPath):
        discriminator.load_state_dict(torch.load(DiscriminatorModelPath, map_location=device))
        print(f"已成功载入判别器模型权重从 '{DiscriminatorModelPath}'。")
    if os.path.exists(DiscriminatorOptimizerPath):
        optimizer_discriminator.load_state_dict(torch.load(DiscriminatorOptimizerPath, map_location=device))
        print(f"已成功载入判别器优化器状态从 '{DiscriminatorOptimizerPath}'。")

    # Build samplers and loaders.
    indices = list(range(total_samples))
    random_indices = random.sample(indices, sample_size)
    random_indices1 = random.sample(indices, sample_size1)
    random_indices2 = random.sample(indices, sample_size2)
    random_indicesAug = random.sample(indices, sample_size)

    sampler = SubsetRandomSampler(random_indices)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=2)
    sampler1 = SubsetRandomSampler(random_indices1)
    # BUG FIX: this loader previously reused `sampler`, leaving the
    # freshly-sampled `random_indicesAug` unused.
    samplerAug = SubsetRandomSampler(random_indicesAug)
    dataloaderAug = DataLoader(datasetAug, batch_size=batch_size, sampler=samplerAug, num_workers=2)
    dataloader1 = DataLoader(datasetAug, batch_size=8, sampler=sampler1, num_workers=2)
    sampler2 = SubsetRandomSampler(random_indices2)
    dataloader2 = DataLoader(dataset, batch_size=batch_size, sampler=sampler2, num_workers=2)
    sampler3 = SubsetRandomSampler(random_indices2)
    # BUG FIX: this loader previously reused `sampler2` instead of `sampler3`.
    dataloader3 = DataLoader(datasetAug, batch_size=batch_size, sampler=sampler3, num_workers=2)

    print("開始訓練...")
    for epoch in range(1, epochs + 1):
        # --- Phase 1: supervised training on the main loader ---
        try:
            avg_loss = generator.train_epoch(
                dataloader,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )
        except KeyboardInterrupt:
            print("训练过程被手动中断。")
            break
        except Exception as e:
            print(f"在训练 epoch {epoch} 时发生错误: {e}")

        # Phases 2 and 3 only run on the final epoch.
        # (Was hard-coded as `epoch != 4`; tied to `epochs` instead.)
        if epoch != epochs:
            continue

        # --- Phase 2: reinforcement training on the augmented loader ---
        print("開始強化訓練...")
        try:
            avg_loss = generator.train_epoch(
                dataloaderAug,
                optimizer_generator,
                nn.MSELoss(),
                nn.BCELoss(),
                epoch
            )
            # Persist generator + discriminator state after reinforcement.
            generator.save_model(epoch, avg_loss)
            torch.save(discriminator.state_dict(), DiscriminatorModelPath)
            torch.save(optimizer_discriminator.state_dict(), DiscriminatorOptimizerPath)
            print(f"判别器模型和优化器已保存至 '{DiscriminatorModelPath}' 和 '{DiscriminatorOptimizerPath}'。")
        except KeyboardInterrupt:
            print("训练过程被手动中断。")
            break
        except Exception as e:
            print(f"在训练 epoch {epoch} 时发生错误: {e}")

        # --- Phase 3: adversarial (GAN) training ---
        print("開始對抗訓練...")
        try:
            avg_loss = generator.train_epoch_gan(
                dataloader1,
                optimizer_generator,
                optimizer_discriminator,
                nn.MSELoss(),
                nn.BCELoss(),
                criterion_discriminator,
                discriminator,
                epoch
            )
        except KeyboardInterrupt:
            print("训练过程被手动中断。")
            break
        except Exception as e:
            print(f"在训练 epoch {epoch} 时发生错误: {e}")
            continue  # move on to the next epoch

    writer.close()
1308
+
1309
def loadMusicGenerator():
    """Construct and return a ``(generator, evaluator)`` pair for inference.

    The model hyper-parameters are fixed and must match the values used at
    training time so the saved checkpoint loads correctly.
    """
    # TensorBoard writer (kept for API parity with training).
    tb_writer = SummaryWriter(log_dir=os.path.join(Gbase, 'runs'))

    # Fresh tag evaluator (intentionally NOT loaded from EvaluatorPath).
    tag_evaluator = MusicTagEvaluator()

    # Run on GPU when available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"使用设备: {device}")

    # Hyper-parameters — must match the trained checkpoint.
    net = MusicGenerationModel(
        3,                            # input_dim: pitch, duration, volume
        512,                          # d_model
        8,                            # nhead
        8,                            # num_encoder_layers
        2048,                         # dim_feedforward
        3,                            # output_dim: pitch, duration, volume
        len(tag_evaluator.all_tags),  # num_tags
    ).to(device)

    music_generator = AdvancedMusicGenerator(
        net, tag_evaluator, device, model_path=ModelPath, writer=tb_writer)
    return music_generator, tag_evaluator
1339
+
1340
+
1341
# Module-level generator/evaluator pair consumed by the Gradio app below.
MyMusicGenerator, MyMusicTagEvaluator = loadMusicGenerator()

import gradio as gr
import numpy as np
import time
import os

# Assuming your existing functions and setup are defined above
1349
+
1350
def generate_music(*tags, use_random=False):
    """Gradio click handler: generate music for the selected tag values.

    BUG FIX: Gradio passes every component value positionally, so the
    trailing "Use Random Tags" checkbox used to land inside ``*tags`` (then
    be silently dropped by ``zip``) while the keyword-only ``use_random``
    stayed False — the checkbox never had any effect. Detect and peel off a
    trailing bool so the checkbox works; direct keyword calls still behave
    as before.

    Returns:
        (wav_file_path, tags_dict) for the Audio and JSON outputs.
    """
    if tags and isinstance(tags[-1], bool):
        use_random = tags[-1]
        tags = tags[:-1]

    if use_random:
        tags_dict = randomMusicTags()
    else:
        # Dropdown values arrive in the same order as MUSIC_TAGS' keys.
        tags_dict = dict(zip(MUSIC_TAGS.keys(), tags))

    # Generate the music21 stream with a mildly randomized temperature.
    generated_stream = MyMusicGenerator.generate_music(
        tag_conditions=tags_dict, max_length=130,
        temperature=np.random.uniform(0.7, 1.1))

    # Save a MIDI copy alongside the WAV.
    midi_filename = f"music_{int(time.time())}.mid"
    mid_path = os.path.join(Gbase, midi_filename)
    generated_stream.write('midi', fp=mid_path)

    # Render to WAV for the Gradio audio component.
    wav_file = MyMusicGenerator.custom_midi_to_wav(
        generated_stream, os.path.join(Gbase, f"{midi_filename[:-4]}.wav"))

    return wav_file, tags_dict
1369
+
1370
# Define the Gradio interface: one dropdown per tag category on the left,
# controls and outputs on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Music Generation with Tags")

    with gr.Row():
        with gr.Column():
            # List comprehension to create dropdowns for each tag category,
            # preselecting the first option of each.
            tag_inputs = [
                gr.Dropdown(value=MUSIC_TAGS[category][0] ,choices=MUSIC_TAGS[category], label=category.capitalize())
                for category in MUSIC_TAGS.keys()
            ]
        with gr.Column():
            use_random = gr.Checkbox(label="Use Random Tags")
            generate_btn = gr.Button("Generate Music")
            output_audio = gr.Audio(label="Generated Music")
            output_tags = gr.JSON(label="Generated Tags")

    # Pass the list of dropdowns directly instead of using gr.Group.
    # NOTE(review): Gradio supplies all inputs positionally, so the checkbox
    # value arrives as the LAST positional argument of generate_music.
    generate_btn.click(
        fn=generate_music,
        inputs=[*tag_inputs, use_random],
        outputs=[output_audio, output_tags]
    )

# Launch the interface
demo.launch()