admin committed on
Commit
93fe177
·
1 Parent(s): bba92ce

upd gr ver

Browse files
Files changed (5) hide show
  1. README.md +1 -1
  2. app.py +11 -11
  3. generate.py +1 -1
  4. model.py +14 -27
  5. utils.py +15 -7
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎶😆😠😟
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.32.0
8
  app_file: app.py
9
  pinned: false
10
  license: lgpl-3.0
 
4
  colorFrom: indigo
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 6.6.0
8
  app_file: app.py
9
  pinned: false
10
  license: lgpl-3.0
app.py CHANGED
@@ -162,8 +162,6 @@ if __name__ == "__main__":
162
  gr.Video(
163
  "./demo.mp4" if EN_US else "./src/tutorial.mp4",
164
  label=_L("Video demo"),
165
- show_download_button=False,
166
- show_share_button=False,
167
  )
168
  gr.Markdown(
169
  f"## {_L('Cite')}"
@@ -215,9 +213,6 @@ if __name__ == "__main__":
215
  else "./src/4q.jpg"
216
  ),
217
  show_label=False,
218
- show_download_button=False,
219
- show_fullscreen_button=False,
220
- show_share_button=False,
221
  )
222
  v_radio = gr.Radio(
223
  [_L("Low"), _L("High")],
@@ -283,7 +278,11 @@ if __name__ == "__main__":
283
  save_file = gr.File(label=_L("Download template"))
284
 
285
  with gr.Column():
286
- wav_audio = gr.Audio(label=_L("Audio"), type="filepath")
 
 
 
 
287
  with gr.Accordion(label=_L("Feedback"), open=False):
288
  fdb_radio = gr.Radio(
289
  ["Q1", "Q2", "Q3", "Q4"],
@@ -293,7 +292,7 @@ if __name__ == "__main__":
293
  )
294
  fdb_btn = gr.Button(_L("Submit"))
295
 
296
- status_bar = gr.Textbox(label=_L("Status"), show_copy_button=True)
297
  with gr.Row():
298
  mid_file = gr.File(label=_L("Download MIDI"), min_width=80)
299
  pdf_file = gr.File(label=_L("Download PDF score"), min_width=80)
@@ -301,11 +300,12 @@ if __name__ == "__main__":
301
  mxl_file = gr.File(label=_L("Download MXL"), min_width=80)
302
 
303
  with gr.Row():
304
- abc_txt = gr.TextArea(
305
- label=_L("ABC notation"),
306
- show_copy_button=True,
 
 
307
  )
308
- staff_img = gr.Image(label=_L("Staff"), type="filepath")
309
 
310
  # actions
311
  gen1_btn.click(
 
162
  gr.Video(
163
  "./demo.mp4" if EN_US else "./src/tutorial.mp4",
164
  label=_L("Video demo"),
 
 
165
  )
166
  gr.Markdown(
167
  f"## {_L('Cite')}"
 
213
  else "./src/4q.jpg"
214
  ),
215
  show_label=False,
 
 
 
216
  )
217
  v_radio = gr.Radio(
218
  [_L("Low"), _L("High")],
 
278
  save_file = gr.File(label=_L("Download template"))
279
 
280
  with gr.Column():
281
+ wav_audio = gr.Audio(
282
+ label=_L("Audio"),
283
+ type="filepath",
284
+ buttons=["download"],
285
+ )
286
  with gr.Accordion(label=_L("Feedback"), open=False):
287
  fdb_radio = gr.Radio(
288
  ["Q1", "Q2", "Q3", "Q4"],
 
292
  )
293
  fdb_btn = gr.Button(_L("Submit"))
294
 
295
+ status_bar = gr.Textbox(label=_L("Status"), buttons=["copy"])
296
  with gr.Row():
297
  mid_file = gr.File(label=_L("Download MIDI"), min_width=80)
298
  pdf_file = gr.File(label=_L("Download PDF score"), min_width=80)
 
300
  mxl_file = gr.File(label=_L("Download MXL"), min_width=80)
301
 
302
  with gr.Row():
303
+ abc_txt = gr.TextArea(label=_L("ABC notation"), buttons=["copy"])
304
+ staff_img = gr.Image(
305
+ label=_L("Staff"),
306
+ type="filepath",
307
+ buttons=["fullscreen", "download"],
308
  )
 
309
 
310
  # actions
311
  gen1_btn.click(
generate.py CHANGED
@@ -113,7 +113,7 @@ def generate_music(
113
  )
114
  model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
115
  checkpoint = torch.load(weights, map_location=DEVICE)
116
- model.load_state_dict(checkpoint["model"])
117
  model = model.to(DEVICE)
118
  model.eval()
119
  prompt = f"A:{emo}\n"
 
113
  )
114
  model = TunesFormer(patch_config, char_config, share_weights=SHARE_WEIGHTS)
115
  checkpoint = torch.load(weights, map_location=DEVICE)
116
+ model.load_state_dict(checkpoint["model"], strict=False)
117
  model = model.to(DEVICE)
118
  model.eval()
119
  prompt = f"A:{emo}\n"
model.py CHANGED
@@ -66,10 +66,8 @@ class Patchilizer:
66
  """
67
  lines = unidecode(abc_code).split("\n")
68
  lines = list(filter(None, lines)) # remove empty lines
69
-
70
  body = ""
71
  patches = []
72
-
73
  for line in lines:
74
  if len(line) > 1 and (
75
  (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
@@ -129,7 +127,6 @@ class PatchLevelDecoder(PreTrainedModel):
129
  patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
130
  patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
131
  patches = self.patch_embedding(patches.to(self.device))
132
-
133
  return self.base(inputs_embeds=patches)
134
 
135
 
@@ -161,11 +158,9 @@ class CharLevelDecoder(PreTrainedModel):
161
  # preparing the labels for model training
162
  target_masks = target_patches == self.pad_token_id
163
  labels = target_patches.clone().masked_fill_(target_masks, -100)
164
-
165
  # masking the labels for model training
166
  target_masks = torch.ones_like(labels)
167
  target_masks = target_masks.masked_fill_(labels == -100, 0)
168
-
169
  # select patches
170
  if (
171
  patch_sampling_batch_size != 0
@@ -174,7 +169,6 @@ class CharLevelDecoder(PreTrainedModel):
174
  indices = list(range(len(target_patches)))
175
  random.shuffle(indices)
176
  selected_indices = sorted(indices[:patch_sampling_batch_size])
177
-
178
  target_patches = target_patches[selected_indices, :]
179
  target_masks = target_masks[selected_indices, :]
180
  encoded_patches = encoded_patches[selected_indices, :]
@@ -184,12 +178,10 @@ class CharLevelDecoder(PreTrainedModel):
184
  inputs_embeds = torch.nn.functional.embedding(
185
  target_patches, self.base.transformer.wte.weight
186
  )
187
-
188
  # concatenate the encoded patches with the input embeddings
189
  inputs_embeds = torch.cat(
190
  (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
191
  )
192
-
193
  return self.base(
194
  inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
195
  )
@@ -203,20 +195,14 @@ class CharLevelDecoder(PreTrainedModel):
203
  """
204
  encoded_patch = encoded_patch.reshape(1, 1, -1)
205
  tokens = tokens.reshape(1, -1)
206
-
207
  # Get input embeddings
208
  tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
209
-
210
  # Concatenate the encoded patch with the input embeddings
211
  tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
212
-
213
  # Get output from model
214
  outputs = self.base(inputs_embeds=tokens)
215
-
216
  # Get probabilities of next token
217
- probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
218
-
219
- return probs
220
 
221
 
222
  class TunesFormer(PreTrainedModel):
@@ -235,14 +221,11 @@ class TunesFormer(PreTrainedModel):
235
  max_layers = max(
236
  encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
237
  )
238
-
239
  max_context_size = max(encoder_config.max_length, decoder_config.max_length)
240
-
241
  max_position_embeddings = max(
242
  encoder_config.max_position_embeddings,
243
  decoder_config.max_position_embeddings,
244
  )
245
-
246
  encoder_config.num_hidden_layers = max_layers
247
  encoder_config.max_length = max_context_size
248
  encoder_config.max_position_embeddings = max_position_embeddings
@@ -252,7 +235,6 @@ class TunesFormer(PreTrainedModel):
252
 
253
  self.patch_level_decoder = PatchLevelDecoder(encoder_config)
254
  self.char_level_decoder = CharLevelDecoder(decoder_config)
255
-
256
  if share_weights:
257
  self.patch_level_decoder.base = self.char_level_decoder.base.transformer
258
 
@@ -268,13 +250,20 @@ class TunesFormer(PreTrainedModel):
268
  """
269
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
270
  encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
271
-
272
  return self.char_level_decoder(
273
  encoded_patches.squeeze(0)[:-1, :],
274
  patches.squeeze(0)[1:, :],
275
  patch_sampling_batch_size,
276
  )
277
 
 
 
 
 
 
 
 
 
278
  def generate(
279
  self,
280
  patches: torch.Tensor,
@@ -291,13 +280,11 @@ class TunesFormer(PreTrainedModel):
291
  """
292
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
293
  encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
294
-
295
  if tokens == None:
296
  tokens = torch.tensor([self.bos_token_id], device=self.device)
297
 
298
  generated_patch = []
299
  random.seed(seed)
300
-
301
  while True:
302
  if seed != None:
303
  n_seed = random.randint(0, 1000000)
@@ -312,12 +299,13 @@ class TunesFormer(PreTrainedModel):
312
  .detach()
313
  .numpy()
314
  )
315
-
316
  prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
317
  prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
318
-
319
- token = temperature_sampling(prob, temperature=temperature, seed=n_seed)
320
-
 
 
321
  generated_patch.append(token)
322
  if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
323
  break
@@ -333,7 +321,6 @@ class TunesFormer(PreTrainedModel):
333
  class PatchilizedData(Dataset):
334
  def __init__(self, items, patchilizer):
335
  self.texts = []
336
-
337
  for item in tqdm(items):
338
  text = item["control code"] + "\n".join(
339
  item["abc notation"].split("\n")[1:]
 
66
  """
67
  lines = unidecode(abc_code).split("\n")
68
  lines = list(filter(None, lines)) # remove empty lines
 
69
  body = ""
70
  patches = []
 
71
  for line in lines:
72
  if len(line) > 1 and (
73
  (line[0].isalpha() and line[1] == ":") or line.startswith("%%score")
 
127
  patches = torch.nn.functional.one_hot(patches, num_classes=128).float()
128
  patches = patches.reshape(len(patches), -1, PATCH_SIZE * 128)
129
  patches = self.patch_embedding(patches.to(self.device))
 
130
  return self.base(inputs_embeds=patches)
131
 
132
 
 
158
  # preparing the labels for model training
159
  target_masks = target_patches == self.pad_token_id
160
  labels = target_patches.clone().masked_fill_(target_masks, -100)
 
161
  # masking the labels for model training
162
  target_masks = torch.ones_like(labels)
163
  target_masks = target_masks.masked_fill_(labels == -100, 0)
 
164
  # select patches
165
  if (
166
  patch_sampling_batch_size != 0
 
169
  indices = list(range(len(target_patches)))
170
  random.shuffle(indices)
171
  selected_indices = sorted(indices[:patch_sampling_batch_size])
 
172
  target_patches = target_patches[selected_indices, :]
173
  target_masks = target_masks[selected_indices, :]
174
  encoded_patches = encoded_patches[selected_indices, :]
 
178
  inputs_embeds = torch.nn.functional.embedding(
179
  target_patches, self.base.transformer.wte.weight
180
  )
 
181
  # concatenate the encoded patches with the input embeddings
182
  inputs_embeds = torch.cat(
183
  (encoded_patches.unsqueeze(1), inputs_embeds[:, 1:, :]), dim=1
184
  )
 
185
  return self.base(
186
  inputs_embeds=inputs_embeds, attention_mask=target_masks, labels=labels
187
  )
 
195
  """
196
  encoded_patch = encoded_patch.reshape(1, 1, -1)
197
  tokens = tokens.reshape(1, -1)
 
198
  # Get input embeddings
199
  tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
 
200
  # Concatenate the encoded patch with the input embeddings
201
  tokens = torch.cat((encoded_patch, tokens[:, 1:, :]), dim=1)
 
202
  # Get output from model
203
  outputs = self.base(inputs_embeds=tokens)
 
204
  # Get probabilities of next token
205
+ return torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
 
 
206
 
207
 
208
  class TunesFormer(PreTrainedModel):
 
221
  max_layers = max(
222
  encoder_config.num_hidden_layers, decoder_config.num_hidden_layers
223
  )
 
224
  max_context_size = max(encoder_config.max_length, decoder_config.max_length)
 
225
  max_position_embeddings = max(
226
  encoder_config.max_position_embeddings,
227
  decoder_config.max_position_embeddings,
228
  )
 
229
  encoder_config.num_hidden_layers = max_layers
230
  encoder_config.max_length = max_context_size
231
  encoder_config.max_position_embeddings = max_position_embeddings
 
235
 
236
  self.patch_level_decoder = PatchLevelDecoder(encoder_config)
237
  self.char_level_decoder = CharLevelDecoder(decoder_config)
 
238
  if share_weights:
239
  self.patch_level_decoder.base = self.char_level_decoder.base.transformer
240
 
 
250
  """
251
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
252
  encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
 
253
  return self.char_level_decoder(
254
  encoded_patches.squeeze(0)[:-1, :],
255
  patches.squeeze(0)[1:, :],
256
  patch_sampling_batch_size,
257
  )
258
 
259
+ def norm(self, prob):
260
+ prob = [float(x) for x in prob]
261
+ s = sum(prob)
262
+ if s == 0:
263
+ raise ValueError("全零概率")
264
+
265
+ return [x / s for x in prob]
266
+
267
  def generate(
268
  self,
269
  patches: torch.Tensor,
 
280
  """
281
  patches = patches.reshape(len(patches), -1, PATCH_SIZE)
282
  encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"]
 
283
  if tokens == None:
284
  tokens = torch.tensor([self.bos_token_id], device=self.device)
285
 
286
  generated_patch = []
287
  random.seed(seed)
 
288
  while True:
289
  if seed != None:
290
  n_seed = random.randint(0, 1000000)
 
299
  .detach()
300
  .numpy()
301
  )
 
302
  prob = top_p_sampling(prob, top_p=top_p, return_probs=True)
303
  prob = top_k_sampling(prob, top_k=top_k, return_probs=True)
304
+ token = temperature_sampling(
305
+ self.norm(prob),
306
+ temperature=temperature,
307
+ seed=n_seed,
308
+ )
309
  generated_patch.append(token)
310
  if token == self.eos_token_id or len(tokens) >= PATCH_SIZE - 1:
311
  break
 
321
  class PatchilizedData(Dataset):
322
  def __init__(self, items, patchilizer):
323
  self.texts = []
 
324
  for item in tqdm(items):
325
  text = item["control code"] + "\n".join(
326
  item["abc notation"].split("\n")[1:]
utils.py CHANGED
@@ -5,19 +5,27 @@ import torch
5
  import warnings
6
  import requests
7
  import subprocess
8
- import modelscope
9
- import huggingface_hub
10
  from tqdm import tqdm
11
 
12
  warnings.filterwarnings("ignore")
13
 
14
  TEMP_DIR = "./__pycache__"
15
  EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
16
- WEIGHTS_DIR = (
17
- huggingface_hub.snapshot_download("monetjoe/EMelodyGen", cache_dir=TEMP_DIR)
18
- if EN_US
19
- else modelscope.snapshot_download("monetjoe/EMelodyGen", cache_dir=TEMP_DIR)
20
- )
 
 
 
 
 
 
 
 
 
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  PATCH_LENGTH = 128 # Patch Length
23
  PATCH_SIZE = 32 # Patch Size
 
5
  import warnings
6
  import requests
7
  import subprocess
 
 
8
  from tqdm import tqdm
9
 
10
  warnings.filterwarnings("ignore")
11
 
12
  TEMP_DIR = "./__pycache__"
13
  EN_US = os.getenv("LANG") != "zh_CN.UTF-8"
14
+ if EN_US:
15
+ import huggingface_hub
16
+
17
+ WEIGHTS_DIR = huggingface_hub.snapshot_download(
18
+ "monetjoe/EMelodyGen",
19
+ cache_dir=TEMP_DIR,
20
+ )
21
+ else:
22
+ import modelscope
23
+
24
+ WEIGHTS_DIR = modelscope.snapshot_download(
25
+ "monetjoe/EMelodyGen",
26
+ cache_dir=TEMP_DIR,
27
+ )
28
+
29
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  PATCH_LENGTH = 128 # Patch Length
31
  PATCH_SIZE = 32 # Patch Size