projectlosangeles committed on
Commit
ac3d13d
·
verified ·
1 Parent(s): 913cbef

Upload 3 files

Browse files
Files changed (3) hide show
  1. TMIDIX.py +1462 -32
  2. midi_to_colab_audio.py +775 -228
  3. x_transformer_2_3_1.py +412 -2
TMIDIX.py CHANGED
@@ -51,7 +51,7 @@ r'''############################################################################
51
 
52
  ###################################################################################
53
 
54
- __version__ = "25.7.8"
55
 
56
  print('=' * 70)
57
  print('TMIDIX Python module')
@@ -1485,10 +1485,13 @@ import multiprocessing
1485
 
1486
  from itertools import zip_longest
1487
  from itertools import groupby
 
 
1488
 
1489
  from collections import Counter
1490
  from collections import defaultdict
1491
  from collections import OrderedDict
 
1492
 
1493
  from operator import itemgetter
1494
 
@@ -1498,6 +1501,9 @@ from difflib import SequenceMatcher as SM
1498
 
1499
  import statistics
1500
  import math
 
 
 
1501
 
1502
  import matplotlib.pyplot as plt
1503
 
@@ -3903,7 +3909,8 @@ def chordify_score(score,
3903
 
3904
  def fix_monophonic_score_durations(monophonic_score,
3905
  min_notes_gap=1,
3906
- min_notes_dur=1
 
3907
  ):
3908
 
3909
  fixed_score = []
@@ -3918,7 +3925,11 @@ def fix_monophonic_score_durations(monophonic_score,
3918
  if note[1]+note[2] >= nmt:
3919
  note_dur = max(1, nmt-note[1]-min_notes_gap)
3920
  else:
3921
- note_dur = note[2]
 
 
 
 
3922
 
3923
  new_note = [note[0], note[1], note_dur] + note[3:]
3924
 
@@ -3936,9 +3947,13 @@ def fix_monophonic_score_durations(monophonic_score,
3936
  nmt = monophonic_score[i+1][0]
3937
 
3938
  if note[0]+note[1] >= nmt:
3939
- note_dur = max(1, nmt-note[0]-min_notes_gap)
3940
  else:
3941
- note_dur = note[1]
 
 
 
 
3942
 
3943
  new_note = [note[0], note_dur] + note[2:]
3944
 
@@ -3952,8 +3967,6 @@ def fix_monophonic_score_durations(monophonic_score,
3952
 
3953
  ###################################################################################
3954
 
3955
- from itertools import product
3956
-
3957
  ALL_CHORDS = [[0], [7], [5], [9], [2], [4], [11], [10], [8], [6], [3], [1], [0, 9], [2, 5],
3958
  [4, 7], [7, 10], [2, 11], [0, 3], [6, 9], [1, 4], [8, 11], [5, 8], [1, 10],
3959
  [3, 6], [0, 4], [5, 9], [7, 11], [0, 7], [0, 5], [2, 10], [2, 7], [2, 9],
@@ -7128,7 +7141,8 @@ def escore_notes_to_binary_matrix(escore_notes,
7128
  channel=0,
7129
  patch=0,
7130
  flip_matrix=False,
7131
- reverse_matrix=False
 
7132
  ):
7133
 
7134
  escore = [e for e in escore_notes if e[3] == channel and e[6] == patch]
@@ -7152,14 +7166,17 @@ def escore_notes_to_binary_matrix(escore_notes,
7152
  duration = max(1, duration)
7153
  chan = max(0, min(15, chan))
7154
  pitch = max(0, min(127, pitch))
7155
- velocity = max(0, min(127, velocity))
7156
  pat = max(0, min(128, pat))
7157
 
7158
  if channel == chan and patch == pat:
7159
 
7160
  for t in range(time, min(time + duration, time_range)):
7161
-
7162
- escore_matrix[t][pitch] = 1
 
 
 
7163
 
7164
  if flip_matrix:
7165
 
@@ -7183,7 +7200,8 @@ def escore_notes_to_binary_matrix(escore_notes,
7183
  def binary_matrix_to_original_escore_notes(binary_matrix,
7184
  channel=0,
7185
  patch=0,
7186
- velocity=-1
 
7187
  ):
7188
 
7189
  result = []
@@ -7222,8 +7240,11 @@ def binary_matrix_to_original_escore_notes(binary_matrix,
7222
 
7223
  for r in result:
7224
 
7225
- if velocity == -1:
7226
- vel = max(40, r[2])
 
 
 
7227
 
7228
  original_escore_notes.append(['note', r[0], r[1], channel, r[2], vel, patch])
7229
 
@@ -8048,7 +8069,7 @@ def solo_piano_escore_notes(escore_notes,
8048
  keep_drums=False,
8049
  ):
8050
 
8051
- cscore = chordify_score([1000, escore_notes])
8052
 
8053
  sp_escore_notes = []
8054
 
@@ -9720,7 +9741,14 @@ def escore_notes_to_text_description(escore_notes,
9720
  song_name='',
9721
  artist_name='',
9722
  timings_divider=16,
 
 
9723
  ):
 
 
 
 
 
9724
 
9725
  #==============================================================================
9726
 
@@ -9734,6 +9762,9 @@ def escore_notes_to_text_description(escore_notes,
9734
 
9735
  elif song_time_min >= 2.5:
9736
  song_length = 'long'
 
 
 
9737
 
9738
  #==============================================================================
9739
 
@@ -9745,18 +9776,25 @@ def escore_notes_to_text_description(escore_notes,
9745
  if len(escore_times) == len(set(escore_times)):
9746
  comp_type = 'monophonic melody'
9747
  ctype = 'melody'
 
9748
 
9749
  elif len(escore_times) >= len(set(escore_times)) and 1 in Counter(escore_times).values():
9750
  comp_type = 'melody and accompaniment'
9751
  ctype = 'song'
 
9752
 
9753
  elif len(escore_times) >= len(set(escore_times)) and 1 not in Counter(escore_times).values():
9754
  comp_type = 'accompaniment'
9755
  ctype = 'song'
 
9756
 
9757
  else:
9758
  comp_type = 'drum track'
9759
  ctype = 'drum track'
 
 
 
 
9760
 
9761
  #==============================================================================
9762
 
@@ -9771,6 +9809,13 @@ def escore_notes_to_text_description(escore_notes,
9771
  nd_patches_counts = Counter([p for p in all_patches if p < 128]).most_common()
9772
 
9773
  dominant_instrument = alpha_str(Number2patch[nd_patches_counts[0][0]])
 
 
 
 
 
 
 
9774
 
9775
  if 128 in patches:
9776
  drums_present = True
@@ -9778,9 +9823,16 @@ def escore_notes_to_text_description(escore_notes,
9778
  drums_pitches = [e[4] for e in escore_notes if e[3] == 9]
9779
 
9780
  most_common_drums = [alpha_str(Notenum2percussion[p[0]]) for p in Counter(drums_pitches).most_common(3) if p[0] in Notenum2percussion]
 
 
 
9781
 
9782
  else:
9783
  drums_present = False
 
 
 
 
9784
 
9785
  #==============================================================================
9786
 
@@ -9790,60 +9842,111 @@ def escore_notes_to_text_description(escore_notes,
9790
 
9791
  if pitches:
9792
  key = SEMITONES[statistics.mode(pitches) % 12]
 
 
 
 
 
 
 
9793
 
9794
  #==============================================================================
9795
 
9796
  scale = ''
9797
  mood = ''
9798
 
 
 
 
 
 
9799
  if pitches:
9800
 
9801
  result = escore_notes_scale(escore_notes)
9802
 
9803
  scale = result[0]
9804
  mood = result[1].split(' ')[0].lower()
 
 
 
 
 
 
 
9805
 
9806
  #==============================================================================
9807
-
 
 
 
 
 
 
 
 
 
 
9808
  if pitches:
9809
 
9810
  escore_averages = escore_notes_averages(escore_notes, return_ptcs_and_vels=True)
9811
 
9812
  if escore_averages[0] < (128 / timings_divider):
9813
  rythm = 'fast'
 
9814
 
9815
  elif (128 / timings_divider) <= escore_averages[0] <= (192 / timings_divider):
9816
  rythm = 'average'
 
9817
 
9818
  elif escore_averages[0] > (192 / timings_divider):
9819
  rythm = 'slow'
 
9820
 
9821
  if escore_averages[1] < (256 / timings_divider):
9822
  tempo = 'fast'
 
9823
 
9824
  elif (256 / timings_divider) <= escore_averages[1] <= (384 / timings_divider):
9825
  tempo = 'average'
 
9826
 
9827
  elif escore_averages[1] > (384 / timings_divider):
9828
  tempo = 'slow'
 
9829
 
9830
  if escore_averages[2] < 50:
9831
  tone = 'bass'
 
9832
 
9833
  elif 50 <= escore_averages[2] <= 70:
9834
  tone = 'midrange'
 
9835
 
9836
  elif escore_averages[2] > 70:
9837
  tone = 'treble'
 
9838
 
9839
  if escore_averages[3] < 64:
9840
  dynamics = 'quiet'
 
9841
 
9842
  elif 64 <= escore_averages[3] <= 96:
9843
  dynamics = 'average'
 
9844
 
9845
  elif escore_averages[3] > 96:
9846
  dynamics = 'loud'
 
 
 
 
 
 
 
 
 
 
 
9847
 
9848
  #==============================================================================
9849
 
@@ -9851,6 +9954,12 @@ def escore_notes_to_text_description(escore_notes,
9851
 
9852
  lead_melodies = []
9853
  base_melodies = []
 
 
 
 
 
 
9854
 
9855
  if mono_melodies:
9856
 
@@ -9860,15 +9969,19 @@ def escore_notes_to_text_description(escore_notes,
9860
 
9861
  if mel[0] in LEAD_INSTRUMENTS and escore_avgs[3] > 60:
9862
  lead_melodies.append([Number2patch[mel[0]], mel[1]])
 
9863
 
9864
  elif mel[0] in BASE_INSTRUMENTS and escore_avgs[3] <= 60:
9865
  base_melodies.append([Number2patch[mel[0]], mel[1]])
 
9866
 
9867
  if lead_melodies:
9868
  lead_melodies.sort(key=lambda x: x[1], reverse=True)
 
9869
 
9870
  if base_melodies:
9871
  base_melodies.sort(key=lambda x: x[1], reverse=True)
 
9872
 
9873
  #==============================================================================
9874
 
@@ -10055,8 +10168,20 @@ def escore_notes_to_text_description(escore_notes,
10055
  description += '\n'
10056
 
10057
  #==============================================================================
10058
-
10059
- return description
 
 
 
 
 
 
 
 
 
 
 
 
10060
 
10061
  ###################################################################################
10062
 
@@ -11282,21 +11407,27 @@ def escore_notes_core(escore_notes, core_len=128):
11282
 
11283
  ###################################################################################
11284
 
11285
- def multiprocessing_wrapper(function, data_list, verbose=True):
 
 
 
11286
 
11287
- with multiprocessing.Pool() as pool:
11288
-
11289
- results = []
11290
-
11291
- for result in tqdm.tqdm(pool.imap_unordered(function, data_list),
11292
- total=len(data_list),
11293
- disable=not verbose
11294
- ):
11295
-
 
 
11296
  results.append(result)
11297
-
11298
  return results
11299
 
 
11300
  ###################################################################################
11301
 
11302
  def rle_encode_ones(matrix, div_mod=-1):
@@ -11479,9 +11610,10 @@ def create_files_list(datasets_paths=['./'],
11479
 
11480
  for dataset_addr in datasets_paths:
11481
 
11482
- print('=' * 70)
11483
- print('Processing', dataset_addr)
11484
- print('=' * 70)
 
11485
 
11486
  for dirpath, dirnames, filenames in tqdm.tqdm(os.walk(dataset_addr), disable=not verbose):
11487
 
@@ -13604,6 +13736,1304 @@ PERCUSSION_GROUPS = {
13604
 
13605
  ###################################################################################
13606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13607
  print('Module loaded!')
13608
  print('=' * 70)
13609
  print('Enjoy! :)')
 
51
 
52
  ###################################################################################
53
 
54
+ __version__ = "25.12.29"
55
 
56
  print('=' * 70)
57
  print('TMIDIX Python module')
 
1485
 
1486
  from itertools import zip_longest
1487
  from itertools import groupby
1488
+ from itertools import cycle
1489
+ from itertools import product
1490
 
1491
  from collections import Counter
1492
  from collections import defaultdict
1493
  from collections import OrderedDict
1494
+ from collections import deque
1495
 
1496
  from operator import itemgetter
1497
 
 
1501
 
1502
  import statistics
1503
  import math
1504
+ from math import gcd
1505
+
1506
+ from functools import reduce
1507
 
1508
  import matplotlib.pyplot as plt
1509
 
 
3909
 
3910
  def fix_monophonic_score_durations(monophonic_score,
3911
  min_notes_gap=1,
3912
+ min_notes_dur=1,
3913
+ extend_durs=False
3914
  ):
3915
 
3916
  fixed_score = []
 
3925
  if note[1]+note[2] >= nmt:
3926
  note_dur = max(1, nmt-note[1]-min_notes_gap)
3927
  else:
3928
+ if extend_durs:
3929
+ note_dur = max(1, nmt-note[1]-min_notes_gap)
3930
+
3931
+ else:
3932
+ note_dur = note[2]
3933
 
3934
  new_note = [note[0], note[1], note_dur] + note[3:]
3935
 
 
3947
  nmt = monophonic_score[i+1][0]
3948
 
3949
  if note[0]+note[1] >= nmt:
3950
+ note_dur = max(1, nmt-note[0]-min_notes_gap)
3951
  else:
3952
+ if extend_durs:
3953
+ note_dur = max(1, nmt-note[0]-min_notes_gap)
3954
+
3955
+ else:
3956
+ note_dur = note[1]
3957
 
3958
  new_note = [note[0], note_dur] + note[2:]
3959
 
 
3967
 
3968
  ###################################################################################
3969
 
 
 
3970
  ALL_CHORDS = [[0], [7], [5], [9], [2], [4], [11], [10], [8], [6], [3], [1], [0, 9], [2, 5],
3971
  [4, 7], [7, 10], [2, 11], [0, 3], [6, 9], [1, 4], [8, 11], [5, 8], [1, 10],
3972
  [3, 6], [0, 4], [5, 9], [7, 11], [0, 7], [0, 5], [2, 10], [2, 7], [2, 9],
 
7141
  channel=0,
7142
  patch=0,
7143
  flip_matrix=False,
7144
+ reverse_matrix=False,
7145
+ encode_velocities=False
7146
  ):
7147
 
7148
  escore = [e for e in escore_notes if e[3] == channel and e[6] == patch]
 
7166
  duration = max(1, duration)
7167
  chan = max(0, min(15, chan))
7168
  pitch = max(0, min(127, pitch))
7169
+ velocity = max(1, min(127, velocity))
7170
  pat = max(0, min(128, pat))
7171
 
7172
  if channel == chan and patch == pat:
7173
 
7174
  for t in range(time, min(time + duration, time_range)):
7175
+ if encode_velocities:
7176
+ escore_matrix[t][pitch] = velocity
7177
+
7178
+ else:
7179
+ escore_matrix[t][pitch] = 1
7180
 
7181
  if flip_matrix:
7182
 
 
7200
  def binary_matrix_to_original_escore_notes(binary_matrix,
7201
  channel=0,
7202
  patch=0,
7203
+ velocity=-1,
7204
+ decode_velocities=False
7205
  ):
7206
 
7207
  result = []
 
7240
 
7241
  for r in result:
7242
 
7243
+ if velocity == -1 and not decode_velocities:
7244
+ vel = max(40, r[2])
7245
+
7246
+ if decode_velocities:
7247
+ vel = r[3]
7248
 
7249
  original_escore_notes.append(['note', r[0], r[1], channel, r[2], vel, patch])
7250
 
 
8069
  keep_drums=False,
8070
  ):
8071
 
8072
+ cscore = chordify_score([1000, copy.deepcopy(escore_notes)])
8073
 
8074
  sp_escore_notes = []
8075
 
 
9741
  song_name='',
9742
  artist_name='',
9743
  timings_divider=16,
9744
+ return_feat_dict=False,
9745
+ return_feat_dict_vals=False
9746
  ):
9747
+
9748
+ #==============================================================================
9749
+
9750
+ feat_dict = {}
9751
+ feat_dict_vals = {}
9752
 
9753
  #==============================================================================
9754
 
 
9762
 
9763
  elif song_time_min >= 2.5:
9764
  song_length = 'long'
9765
+
9766
+ feat_dict['song_len'] = song_length.capitalize()
9767
+ feat_dict_vals['song_len'] = song_time_min
9768
 
9769
  #==============================================================================
9770
 
 
9776
  if len(escore_times) == len(set(escore_times)):
9777
  comp_type = 'monophonic melody'
9778
  ctype = 'melody'
9779
+ ctv = 0
9780
 
9781
  elif len(escore_times) >= len(set(escore_times)) and 1 in Counter(escore_times).values():
9782
  comp_type = 'melody and accompaniment'
9783
  ctype = 'song'
9784
+ ctv = 1
9785
 
9786
  elif len(escore_times) >= len(set(escore_times)) and 1 not in Counter(escore_times).values():
9787
  comp_type = 'accompaniment'
9788
  ctype = 'song'
9789
+ ctv = 2
9790
 
9791
  else:
9792
  comp_type = 'drum track'
9793
  ctype = 'drum track'
9794
+ ctv = 3
9795
+
9796
+ feat_dict['song_type'] = comp_type.capitalize()
9797
+ feat_dict_vals['song_type'] = ctv
9798
 
9799
  #==============================================================================
9800
 
 
9809
  nd_patches_counts = Counter([p for p in all_patches if p < 128]).most_common()
9810
 
9811
  dominant_instrument = alpha_str(Number2patch[nd_patches_counts[0][0]])
9812
+
9813
+ feat_dict['most_com_instr'] = instruments
9814
+ feat_dict_vals['most_com_instr'] = [p for p in patches if p < 128]
9815
+
9816
+ else:
9817
+ feat_dict['most_com_instr'] = None
9818
+ feat_dict_vals['most_com_instr'] = []
9819
 
9820
  if 128 in patches:
9821
  drums_present = True
 
9823
  drums_pitches = [e[4] for e in escore_notes if e[3] == 9]
9824
 
9825
  most_common_drums = [alpha_str(Notenum2percussion[p[0]]) for p in Counter(drums_pitches).most_common(3) if p[0] in Notenum2percussion]
9826
+
9827
+ feat_dict['most_com_drums'] = most_common_drums
9828
+ feat_dict_vals['most_com_drums'] = [p[0] for p in Counter(drums_pitches).most_common(3)]
9829
 
9830
  else:
9831
  drums_present = False
9832
+
9833
+ feat_dict['most_com_drums'] = None
9834
+
9835
+ feat_dict_vals['most_com_drums'] = []
9836
 
9837
  #==============================================================================
9838
 
 
9842
 
9843
  if pitches:
9844
  key = SEMITONES[statistics.mode(pitches) % 12]
9845
+
9846
+ feat_dict['key'] = key.title()
9847
+ feat_dict_vals['key'] = statistics.mode(pitches) % 12
9848
+
9849
+ else:
9850
+ feat_dict['key'] = None
9851
+ feat_dict_vals['key'] = -1
9852
 
9853
  #==============================================================================
9854
 
9855
  scale = ''
9856
  mood = ''
9857
 
9858
+ feat_dict['scale'] = None
9859
+ feat_dict['mood'] = None
9860
+ feat_dict_vals['scale'] = -1
9861
+ feat_dict_vals['mood'] = -1
9862
+
9863
  if pitches:
9864
 
9865
  result = escore_notes_scale(escore_notes)
9866
 
9867
  scale = result[0]
9868
  mood = result[1].split(' ')[0].lower()
9869
+
9870
+ feat_dict['scale'] = scale.title()
9871
+ feat_dict['mood'] = mood.title()
9872
+
9873
+ res = escore_notes_scale(escore_notes, return_scale_indexes=True)
9874
+ feat_dict_vals['scale'] = res[0]
9875
+ feat_dict_vals['mood'] = res[1]
9876
 
9877
  #==============================================================================
9878
+
9879
+ feat_dict['rythm'] = None
9880
+ feat_dict['tempo'] = None
9881
+ feat_dict['tone'] = None
9882
+ feat_dict['dynamics'] = None
9883
+
9884
+ feat_dict_vals['rythm'] = -1
9885
+ feat_dict_vals['tempo'] = -1
9886
+ feat_dict_vals['tone'] = -1
9887
+ feat_dict_vals['dynamics'] = -1
9888
+
9889
  if pitches:
9890
 
9891
  escore_averages = escore_notes_averages(escore_notes, return_ptcs_and_vels=True)
9892
 
9893
  if escore_averages[0] < (128 / timings_divider):
9894
  rythm = 'fast'
9895
+ ryv = 0
9896
 
9897
  elif (128 / timings_divider) <= escore_averages[0] <= (192 / timings_divider):
9898
  rythm = 'average'
9899
+ ryv = 1
9900
 
9901
  elif escore_averages[0] > (192 / timings_divider):
9902
  rythm = 'slow'
9903
+ ryv = 2
9904
 
9905
  if escore_averages[1] < (256 / timings_divider):
9906
  tempo = 'fast'
9907
+ tev = 0
9908
 
9909
  elif (256 / timings_divider) <= escore_averages[1] <= (384 / timings_divider):
9910
  tempo = 'average'
9911
+ tev = 1
9912
 
9913
  elif escore_averages[1] > (384 / timings_divider):
9914
  tempo = 'slow'
9915
+ tev = 2
9916
 
9917
  if escore_averages[2] < 50:
9918
  tone = 'bass'
9919
+ tov = 0
9920
 
9921
  elif 50 <= escore_averages[2] <= 70:
9922
  tone = 'midrange'
9923
+ tov = 1
9924
 
9925
  elif escore_averages[2] > 70:
9926
  tone = 'treble'
9927
+ tov = 2
9928
 
9929
  if escore_averages[3] < 64:
9930
  dynamics = 'quiet'
9931
+ dyn = 0
9932
 
9933
  elif 64 <= escore_averages[3] <= 96:
9934
  dynamics = 'average'
9935
+ dyn = 1
9936
 
9937
  elif escore_averages[3] > 96:
9938
  dynamics = 'loud'
9939
+ dyn = 2
9940
+
9941
+ feat_dict['rythm'] = rythm.title()
9942
+ feat_dict['tempo'] = tempo.title()
9943
+ feat_dict['tone'] = tone.title()
9944
+ feat_dict['dynamics'] = dynamics.title()
9945
+
9946
+ feat_dict_vals['rythm'] = ryv
9947
+ feat_dict_vals['tempo'] = tev
9948
+ feat_dict_vals['tone'] = tov
9949
+ feat_dict_vals['dynamics'] = dyn
9950
 
9951
  #==============================================================================
9952
 
 
9954
 
9955
  lead_melodies = []
9956
  base_melodies = []
9957
+
9958
+ feat_dict['lead_mono_mels'] = None
9959
+ feat_dict['base_mono_mels'] = None
9960
+
9961
+ feat_dict_vals['lead_mono_mels'] = []
9962
+ feat_dict_vals['base_mono_mels'] = []
9963
 
9964
  if mono_melodies:
9965
 
 
9969
 
9970
  if mel[0] in LEAD_INSTRUMENTS and escore_avgs[3] > 60:
9971
  lead_melodies.append([Number2patch[mel[0]], mel[1]])
9972
+ feat_dict_vals['lead_mono_mels'].append(mel[0])
9973
 
9974
  elif mel[0] in BASE_INSTRUMENTS and escore_avgs[3] <= 60:
9975
  base_melodies.append([Number2patch[mel[0]], mel[1]])
9976
+ feat_dict_vals['base_mono_mels'].append(mel[0])
9977
 
9978
  if lead_melodies:
9979
  lead_melodies.sort(key=lambda x: x[1], reverse=True)
9980
+ feat_dict['lead_mono_mels'] = lead_melodies
9981
 
9982
  if base_melodies:
9983
  base_melodies.sort(key=lambda x: x[1], reverse=True)
9984
+ feat_dict['base_mono_mels'] = base_melodies
9985
 
9986
  #==============================================================================
9987
 
 
10168
  description += '\n'
10169
 
10170
  #==============================================================================
10171
+
10172
+ final_feat_dict = []
10173
+
10174
+ if return_feat_dict:
10175
+ final_feat_dict.append(feat_dict)
10176
+
10177
+ if return_feat_dict_vals:
10178
+ final_feat_dict.append(feat_dict_vals)
10179
+
10180
+ if return_feat_dict or return_feat_dict_vals:
10181
+ return final_feat_dict
10182
+
10183
+ else:
10184
+ return description
10185
 
10186
  ###################################################################################
10187
 
 
11407
 
11408
  ###################################################################################
11409
 
11410
def multiprocessing_wrapper(function,
                            data_list,
                            num_workers=None,
                            verbose=True):
    """Apply ``function`` to every item of ``data_list`` in a process pool.

    Args:
        function: Callable executed in the worker processes, one call per item.
        data_list: Sized collection of inputs (``len()`` is used for the
            progress bar total).
        num_workers: Number of worker processes; ``None`` means
            ``multiprocessing.cpu_count()``.
        verbose: When False the tqdm progress bar is disabled.

    Returns:
        List of results in the same order as ``data_list`` (``Pool.imap``
        yields results in input order).
    """

    worker_count = multiprocessing.cpu_count() if num_workers is None else num_workers

    with multiprocessing.Pool(processes=worker_count) as pool:

        progress = tqdm.tqdm(pool.imap(function, data_list),
                             total=len(data_list),
                             disable=not verbose
                             )

        # Drain the iterator while the pool is still alive.
        return list(progress)
11429
 
11430
+
11431
  ###################################################################################
11432
 
11433
  def rle_encode_ones(matrix, div_mod=-1):
 
11610
 
11611
  for dataset_addr in datasets_paths:
11612
 
11613
+ if verbose:
11614
+ print('=' * 70)
11615
+ print('Processing', dataset_addr)
11616
+ print('=' * 70)
11617
 
11618
  for dirpath, dirnames, filenames in tqdm.tqdm(os.walk(dataset_addr), disable=not verbose):
11619
 
 
13736
 
13737
  ###################################################################################
13738
 
13739
def escore_notes_to_expanded_binary_matrix(escore_notes,
                                           channel=0,
                                           patch=0,
                                           flip_matrix=False,
                                           reverse_matrix=False,
                                           encode_velocities=True
                                           ):
    """Render one channel/patch of a score as a time x pitch matrix of pairs.

    Every cell holds a ``(value, count)`` tuple. ``value`` is the note
    velocity (or 1 when ``encode_velocities`` is False) and ``count`` is the
    zero-based offset of that time step from the start of the note. Empty
    cells are ``(0, 0)``.

    Args:
        escore_notes: Iterable of 7-element notes
            ``[type, time, duration, channel, pitch, velocity, patch]``.
        channel: Only notes whose channel equals this value are rendered.
        patch: Only notes whose patch equals this value are rendered.
        flip_matrix: Reverse every row (the 128-entry pitch axis).
        reverse_matrix: Reverse the order of the rows (the time axis).
        encode_velocities: Store the clamped velocity instead of 1.

    Returns:
        A list of rows, each a list of 128 tuples, or ``None`` when no note
        matches ``channel`` and ``patch``.
    """

    escore = [e for e in escore_notes if e[3] == channel and e[6] == patch]

    if not escore:
        return None

    # The matrix length is derived from the LAST note in the list: its start
    # time plus the longest duration among notes sharing that start time.
    # Notes that would run past this point are truncated below.
    last_time = escore[-1][1]
    max_last_dur = max(e[2] for e in escore if e[1] == last_time)

    time_range = last_time + max_last_dur

    escore_matrix = [[(0, 0)] * 128 for _ in range(time_range)]

    for note in escore:

        _etype, time, duration, chan, pitch, velocity, pat = note

        # Clamp every field to the ranges used throughout this function.
        time = max(0, time)
        duration = max(1, duration)
        chan = max(0, min(15, chan))
        pitch = max(0, min(127, pitch))
        velocity = max(1, min(127, velocity))
        pat = max(0, min(128, pat))

        # Re-check after clamping: an out-of-range channel/patch that passed
        # the raw filter above no longer matches here and is skipped.
        if channel == chan and patch == pat:

            cell_value = velocity if encode_velocities else 1

            note_end = min(time + duration, time_range)

            for count, t in enumerate(range(time, note_end)):
                escore_matrix[t][pitch] = (cell_value, count)

    if flip_matrix:
        escore_matrix = [row[::-1] for row in escore_matrix]

    if reverse_matrix:
        escore_matrix = escore_matrix[::-1]

    return escore_matrix
13799
+
13800
+ ###################################################################################
13801
+
13802
def transpose_list(lst):
    """Return the transpose of a list of rows as a list of lists.

    Rows are zipped together, so the result is truncated to the shortest row.
    """
    columns = zip(*lst)
    return list(map(list, columns))
13804
+
13805
+ ###################################################################################
13806
+
13807
def chunk_list(lst, size):
    """Split ``lst`` into consecutive slices of at most ``size`` items.

    The final slice is shorter when ``len(lst)`` is not a multiple of ``size``.
    """
    pieces = []

    for start in range(0, len(lst), size):
        stop = start + size
        pieces.append(lst[start:stop])

    return pieces
13809
+
13810
+ ###################################################################################
13811
+
13812
def flip_list_rows(lst):
    """Return a new list in which every row of ``lst`` is reversed."""
    flipped = []

    for row in lst:
        flipped.append(row[::-1])

    return flipped
13814
+
13815
+ ###################################################################################
13816
+
13817
def flip_list_columns(lst):
    """Return a reversed copy of ``lst`` (row order flipped, rows untouched)."""
    backwards = slice(None, None, -1)
    return lst[backwards]
13819
+
13820
+ ###################################################################################
13821
+
13822
def exists(sub, lst):
    """Return True when ``sub`` occurs in ``lst`` as a contiguous run."""
    window = len(sub)
    last_start = len(lst) - window

    # Slide a window of len(sub) over lst and stop at the first match.
    for start in range(last_start + 1):
        if lst[start:start + window] == sub:
            return True

    return False
13825
+
13826
+ ###################################################################################
13827
+
13828
def exists_noncontig(sub, lst):
    """Return True when ``sub`` is an in-order subsequence of ``lst``.

    Elements need not be adjacent in ``lst``, but their order must match.
    """
    remaining = iter(lst)

    for wanted in sub:
        # ``in`` consumes the iterator up to and including the match, so
        # each later element is searched for only after the previous one.
        if wanted not in remaining:
            return False

    return True
13831
+
13832
+ ###################################################################################
13833
+
13834
def exists_ratio(sub, lst, ratio):
    """Return True when at least ``ratio`` of ``sub``'s items occur in ``lst``.

    Args:
        sub: Sequence of hashable items to look for.
        lst: Iterable of hashable items to search in.
        ratio: Minimum fraction (matches / len(sub)) required.

    Returns:
        bool. An empty ``sub`` is treated as trivially present (True), which
        matches how ``exists`` and ``exists_noncontig`` treat an empty ``sub``
        and avoids a division by zero.
    """
    if len(sub) == 0:
        return True

    # Build the lookup set once instead of once per element of ``sub``.
    pool = set(lst)

    matches = sum(x in pool for x in sub)

    return matches / len(sub) >= ratio
13837
+
13838
+ ###################################################################################
13839
+
13840
def top_k_list_value(lst, k, reverse=True):
    """Return the item at position ``k`` of ``lst`` after sorting.

    With ``reverse=True`` (default) position 0 is the largest item.
    """
    ordered = list(lst)
    ordered.sort(reverse=reverse)

    return ordered[k]
13842
+
13843
+ ###################################################################################
13844
+
13845
def top_k_list_values(lst, k, reverse=True):
    """Return the first ``k`` items of ``lst`` after sorting.

    With ``reverse=True`` (default) these are the ``k`` largest items.
    """
    ordered = list(lst)
    ordered.sort(reverse=reverse)

    return ordered[:k]
13847
+
13848
+ ###################################################################################
13849
+
13850
def concat_rows(lst_A, lst_B):
    """Join the rows of two matrices pairwise with ``+``.

    Row ``i`` of the result is ``lst_A[i] + lst_B[i]``; the result is
    truncated to the shorter of the two inputs.
    """
    joined = []

    for left_row, right_row in zip(lst_A, lst_B):
        joined.append(left_row + right_row)

    return joined
13852
+
13853
+ ###################################################################################
13854
+
13855
def concat_cols(lst_A, lst_B):
    """Combine two matrices cell by cell with ``+``.

    Cell ``[i][j]`` of the result is ``lst_A[i][j] + lst_B[i][j]``; rows and
    cells are paired with ``zip`` and therefore truncated to the shorter side.
    """
    combined = []

    for left_row, right_row in zip(lst_A, lst_B):
        merged_row = []

        for left_cell, right_cell in zip(left_row, right_row):
            merged_row.append(left_cell + right_cell)

        combined.append(merged_row)

    return combined
13857
+
13858
+ ###################################################################################
13859
+
13860
def chunk_by_threshold_mode(nums, threshold=0, normalize=False):
    """Split ``nums`` into consecutive runs that stay close to their mode.

    Values are consumed left to right. A value joins the current run only
    if, after adding it, every member of the run (including the new value)
    is within ``threshold`` of the run's mode. Otherwise the run is closed
    and the value starts a new one. The mode changes only when a value's
    count becomes strictly greater than the current best count.

    Args:
        nums: Sequence of numbers.
        threshold: Maximum allowed absolute distance from the run's mode.
        normalize: When True every closed run is replaced by its mode
            repeated ``len(run)`` times.

    Returns:
        List of runs (lists); ``[]`` for empty input.
    """

    if not nums:
        return []

    groups = []
    current = []
    counts = defaultdict(int)
    best_count = 0
    current_mode = None

    for value in nums:

        if current:

            # Work out what the mode would be if ``value`` were accepted.
            bumped = counts[value] + 1

            if bumped > best_count:
                trial_mode, trial_best = value, bumped

            else:
                trial_mode, trial_best = current_mode, best_count

            # Existing members are checked first, then the candidate itself.
            too_far = (any(abs(member - trial_mode) > threshold for member in current)
                       or abs(value - trial_mode) > threshold)

            if not too_far:
                current.append(value)
                counts[value] = bumped
                current_mode, best_count = trial_mode, trial_best
                continue

            # Candidate rejected: close the run and fall through to restart.
            if normalize:
                groups.append([current_mode] * len(current))

            else:
                groups.append(current[:])

            current = []
            counts.clear()

        # Start a fresh run seeded with ``value``.
        current.append(value)
        counts[value] = 1
        current_mode = value
        best_count = 1

    if current:
        groups.append([current_mode] * len(current) if normalize else current)

    return groups
13938
+
13939
+ ###################################################################################
13940
+
13941
def proportional_adjust(values, target_sum, threshold):
    """Rescale integer ``values`` so that they add up to ``target_sum``.

    Entries below ``threshold`` are "locked" and kept as they are whenever
    possible; the remaining entries are scaled proportionally and rounded
    with the largest-remainder method (ties go to the lower index).

    Fallbacks:
      * every entry is locked: return a copy if the sum already matches,
        otherwise scale all entries;
      * the locked entries alone exceed ``target_sum``: scale all entries;
      * the adjustable entries sum to zero: spread the remainder evenly,
        giving the extra units to the lowest indexes.

    Args:
        values: List of numbers (must support ``.copy()``).
        target_sum: Desired integer total.
        threshold: Entries strictly below this value are locked.

    Returns:
        A new list of the same length; ``[]`` for empty input.
    """

    n = len(values)
    if n == 0:
        return []

    locked_idx = [i for i, v in enumerate(values) if v < threshold]
    locked_lookup = set(locked_idx)  # O(1) membership instead of list scans
    adj_idx = [i for i in range(n) if i not in locked_lookup]

    locked_sum = sum(values[i] for i in locked_idx)
    adj_original_sum = sum(values[i] for i in adj_idx)
    adj_target_sum = target_sum - locked_sum

    def _largest_remainder(scaled, idxs, target):
        # Floor every scaled value, then hand the missing units to the
        # entries with the largest fractional parts (lower index wins ties).
        floored = {i: math.floor(scaled[i]) for i in idxs}
        rem = target - sum(floored.values())

        fracs = sorted(
            ((scaled[i] - floored[i], i) for i in idxs),
            key=lambda x: (x[0], -x[1]),
            reverse=True
        )

        for _, idx in fracs[:rem]:
            floored[idx] += 1

        return floored

    def _proportional_scale(idxs, original, target):
        # Scale the entries at ``idxs`` to sum to ``target``; a non-positive
        # total maps every scaled value to 0 before rounding.
        total = sum(original[i] for i in idxs)

        if total > 0:
            factor = target / total
            scaled_vals = {i: original[i] * factor for i in idxs}

        else:
            scaled_vals = {i: 0 for i in idxs}

        floored = _largest_remainder(scaled_vals, idxs, target)

        result = original.copy()

        for i in idxs:
            result[i] = floored[i]

        return result

    if not adj_idx:
        if locked_sum == target_sum:
            return values.copy()

        return _proportional_scale(locked_idx, values, target_sum)

    if adj_target_sum < 0:
        return _proportional_scale(range(n), values, target_sum)

    if adj_original_sum == 0:
        base = adj_target_sum // len(adj_idx)
        rem = adj_target_sum - base * len(adj_idx)
        result = values.copy()

        for j, idx in enumerate(sorted(adj_idx)):
            increment = base + (1 if j < rem else 0)
            result[idx] = values[idx] + increment

        return result

    # Regular case: scale only the adjustable entries. No positivity guard
    # is applied to the total here, unlike in _proportional_scale.
    factor = adj_target_sum / adj_original_sum
    scaled = {i: values[i] * factor for i in adj_idx}
    floored = _largest_remainder(scaled, adj_idx, adj_target_sum)

    result = values.copy()

    for i in adj_idx:
        result[i] = floored[i]

    return result
14018
+
14019
+ ###################################################################################
14020
+
14021
def advanced_align_escore_notes_to_bars(escore_notes,
                                        bar_dtime=200,
                                        dtimes_adj_thresh=4,
                                        min_dur_gap=0
                                        ):
    """Re-time a score so that every group of onsets spans exactly one bar.

    The distinct onset-to-onset delta times are cut into chunks whose total
    reaches ``bar_dtime``; each chunk is then rescaled with
    ``proportional_adjust`` to sum to exactly ``bar_dtime`` (deltas below
    ``dtimes_adj_thresh`` are kept as they are where possible). The new
    cumulative times are written back to every note of the matching chord
    and durations are finally repaired with ``fix_escore_notes_durations``.

    Note that the trailing, possibly incomplete, chunk is stretched to a full
    bar as well.

    Args:
        escore_notes: list of notes; index 1 is read and written as the
            start time (presumably the usual TMIDIX enhanced score layout).
        bar_dtime: bar length in score time units.
        dtimes_adj_thresh: threshold forwarded to ``proportional_adjust``.
        min_dur_gap: forwarded as ``min_notes_gap`` to
            ``fix_escore_notes_durations``.

    Returns:
        The re-timed score as returned by ``fix_escore_notes_durations``.
    """

    #========================================================

    escore_notes = recalculate_score_timings(escore_notes)

    cscore = chordify_score([1000, escore_notes])

    #========================================================

    # A single delta time must stay shorter than one bar. The cap used to be
    # the literal 199, which only matched the default bar_dtime=200.
    max_dtime = max(1, bar_dtime - 1)

    dtimes = [0] + [min(max_dtime, b[1]-a[1])
                    for a, b in zip(escore_notes[:-1], escore_notes[1:])
                    if b[1]-a[1] != 0]

    #========================================================

    dtimes_chunks = []

    time = 0
    dtime = []

    for dt in dtimes:
        time += dt
        dtime.append(dt)

        if time >= bar_dtime:
            dtimes_chunks.append(dtime)

            time = 0
            dtime = []

    # Leftover deltas that did not fill a whole bar (may be empty).
    dtimes_chunks.append(dtime)

    #========================================================

    fixed_times = []

    time = 0

    for dt in dtimes_chunks:

        adj_dt = proportional_adjust(dt,
                                     bar_dtime,
                                     dtimes_adj_thresh
                                     )

        for t in adj_dt:

            time += t

            fixed_times.append(time)

    #========================================================

    output_score = []

    for i, c in enumerate(cscore):

        # Work on copies so the caller's note lists are not re-timed in place.
        cc = copy.deepcopy(c)
        time = fixed_times[i]

        for e in cc:
            e[1] = time

            output_score.append(e)

    #========================================================

    output_score = fix_escore_notes_durations(output_score,
                                              min_notes_gap=min_dur_gap
                                              )

    #========================================================

    return output_score
14100
+
14101
+ ###################################################################################
14102
+
14103
def check_monophonic_melody(escore_notes,
                            times_idx=1,
                            durs_idx=2
                            ):
    """Return the fraction of notes that overlap the note following them.

    A note counts as overlapping when its start time plus its duration is
    greater than the start time of the next note in the list. The count is
    divided by the total number of notes, so ``0.0`` means no overlaps were
    found.

    Args:
        escore_notes: sequence of notes (indexable records).
        times_idx: index of the start time inside a note.
        durs_idx: index of the duration inside a note.

    Returns:
        Overlap ratio as a float; ``0.0`` for an empty score (this used to
        raise ``ZeroDivisionError``).
    """

    if not escore_notes:
        return 0.0

    bcount = 0

    for cur, nxt in zip(escore_notes[:-1], escore_notes[1:]):
        if cur[times_idx]+cur[durs_idx] > nxt[times_idx]:
            bcount += 1

    return bcount / len(escore_notes)
14115
+
14116
+ ###################################################################################
14117
+
14118
def longest_common_chunk(list1, list2):
    """Return the longest contiguous run of items shared by both lists.

    The search is a binary search over the chunk length combined with a
    Rabin-Karp rolling hash. Every hash hit is confirmed by comparing the
    actual items, so a hash collision can no longer produce a chunk that does
    not occur in ``list2``.

    Args:
        list1: sequence of integers (items take part in hash arithmetic).
        list2: sequence of integers.

    Returns:
        A slice of ``list1`` holding one longest common chunk, or ``[]`` when
        the sequences share nothing (or either one is empty).
    """

    base, mod = 257, 10**9 + 7
    max_len = min(len(list1), len(list2))

    def window_hashes(seq, size):
        # Yield (start_index, rolling_hash) for each window of `size` items.
        h, power = 0, 1

        for i in range(size):
            h = (h * base + seq[i]) % mod
            power = (power * base) % mod

        yield 0, h

        for i in range(size, len(seq)):
            h = (h * base - seq[i - size] * power + seq[i]) % mod
            yield i - size + 1, h

    def find_match(size):
        starts_by_hash = defaultdict(list)

        for start, h in window_hashes(list2, size):
            starts_by_hash[h].append(start)

        for start, h in window_hashes(list1, size):
            bucket = starts_by_hash.get(h)

            if not bucket:
                continue

            candidate = list1[start:start + size]

            # Equal hashes are only a hint: confirm on the real items.
            for other in bucket:
                if list(list2[other:other + size]) == list(candidate):
                    return candidate

        return []

    # Start at 1: a zero-length probe always yields an empty (falsy) chunk,
    # which used to abort the search when the shorter list had one item.
    left, right = 1, max_len
    result = []

    while left <= right:
        mid = (left + right) // 2
        chunk = find_match(mid)

        if chunk:
            result = chunk
            left = mid + 1
        else:
            right = mid - 1

    return result
14174
+
14175
+ ###################################################################################
14176
+
14177
def detect_plateaus(data, min_len=2, tol=0.0):
    """Split ``data`` into consecutive runs whose value spread stays in ``tol``.

    The sequence is scanned left to right. A run is closed as soon as adding
    the current value would make ``max - min`` of the run exceed ``tol``; the
    current value then opens the next run. Items that are not ``int``/``float``
    or that are NaN close the current run and are skipped. Runs shorter than
    ``min_len`` are discarded.

    Args:
        data: sliceable sequence of values.
        min_len: minimum number of items for a run to be reported.
        tol: largest allowed difference between the extremes of a run.

    Returns:
        List of slices of ``data``, in order of appearance.
    """

    found = []
    total = len(data)

    if total < min_len:
        return found

    run_start = 0
    lowest = highest = None  # extremes of the run that is currently open

    for pos, value in enumerate(data):

        if not isinstance(value, (int, float)) or math.isnan(value):
            if pos - run_start >= min_len:
                found.append(data[run_start:pos])

            run_start = pos + 1
            lowest = highest = None
            continue

        if lowest is None:
            lowest = highest = value
        else:
            if value < lowest:
                lowest = value
            if value > highest:
                highest = value

        if highest - lowest > tol:
            if pos - run_start >= min_len:
                found.append(data[run_start:pos])

            # The value that broke the run becomes the first item of the next.
            run_start = pos
            lowest = highest = value

    if total - run_start >= min_len:
        found.append(data[run_start:total])

    return found
14234
+
14235
+ ###################################################################################
14236
+
14237
def alpha_str_to_toks(s, shift=0, add_seos=False):
    """Encode the ASCII letters and spaces of a string as integer tokens.

    Token layout: ``A``-``Z`` map to 0-25 and ``a``-``z`` to 26-51, rotated
    by ``shift`` modulo 52; a space becomes ``52 + shift``; with ``add_seos``
    the marker ``53 + shift`` is put at both ends. Every other character is
    dropped. Non-ASCII letters are dropped too: they used to yield
    out-of-alphabet tokens, or a ``TypeError`` for letters such as ``ß``
    whose upper-case form is more than one character.

    Args:
        s: text to encode.
        shift: rotation applied to letter tokens and offset added to the
            space and start/end tokens.
        add_seos: wrap the sequence in start/end tokens.

    Returns:
        List of integer tokens.
    """

    tokens = []

    if add_seos:
        tokens = [53+shift]

    for char in s:
        if char == ' ':
            tokens.append(52+shift)

        elif 'A' <= char <= 'Z':
            tokens.append((ord(char) - ord('A') + shift) % 52)

        elif 'a' <= char <= 'z':
            tokens.append((26 + ord(char) - ord('a') + shift) % 52)

    if add_seos:
        tokens.append(53+shift)

    return tokens
14258
+
14259
+ ###################################################################################
14260
+
14261
def toks_to_alpha_str(tokens, shift=0, sep=''):
    """Decode letter/space tokens back into a string.

    This is the inverse of the 52-letter token layout: tokens 0-51 are
    rotated back by ``shift`` (modulo 52) and mapped to ``A``-``Z`` (0-25) or
    ``a``-``z`` (26-51); ``52 + shift`` is a space and ``53 + shift`` (the
    start/end marker) is skipped. Any other token is ignored.

    The letter case is now chosen from the un-shifted index. It used to be
    chosen from the raw token, which decoded to wrong characters whenever
    ``shift`` was not zero.

    Args:
        tokens: iterable of integer tokens.
        shift: the shift that was used for encoding.
        sep: string placed between decoded characters.

    Returns:
        The decoded string.
    """

    chars = []

    for token in tokens:
        if token == 53+shift:
            continue

        elif token == 52+shift:
            chars.append(' ')

        elif 0 <= token <= 51:
            original = (token - shift) % 52

            if original < 26:
                chars.append(chr(ord('A') + original))

            else:
                chars.append(chr(ord('a') + (original - 26)))

    return sep.join(chars)
14281
+
14282
+ ###################################################################################
14283
+
14284
def insert_caps_newlines(text):
    """Break ``text`` into lines in front of every upper-case letter.

    When the text contains at least one capitalised word (an upper-case
    letter followed by lower-case letters), each whitespace run that directly
    precedes an upper-case letter is replaced by a single newline.

    Args:
        text: input string.

    Returns:
        The reformatted string. Text without any capitalised word is returned
        unchanged (the function used to return ``None`` in that case).
    """

    if re.search(r'\b[A-Z][a-z]+\b', text):
        return re.sub(r'\s+(?=[A-Z])', '\n', text)

    return text
14290
+
14291
+ ###################################################################################
14292
+
14293
def insert_newlines(text, every=4):
    """Add an extra blank line after every ``every``-th newline of ``text``.

    Args:
        text: input string.
        every: number of newlines between the inserted extra newlines.

    Returns:
        The text with the additional newline characters inserted.
    """

    pieces = []
    seen = 0  # newlines met so far

    for ch in text:
        if ch != '\n':
            pieces.append(ch)
            continue

        seen += 1
        pieces.append('\n')

        if seen % every == 0:
            pieces.append('\n')

    return ''.join(pieces)
14308
+
14309
+ ###################################################################################
14310
+
14311
def symmetric_match_ratio(list_a, list_b, threshold=0):
    """Score how many values of two lists can be paired within ``threshold``.

    Both lists are sorted and walked with two pointers. Whenever the current
    pair differs by at most ``threshold`` it is counted as a match and both
    pointers advance; otherwise the pointer at the smaller value advances.

    Args:
        list_a: numbers.
        list_b: numbers.
        threshold: largest absolute difference accepted as a match.

    Returns:
        ``matches / mean(len(list_a), len(list_b))`` as a float, or ``0.0``
        when both lists are empty.
    """

    a_sorted = sorted(list_a)
    b_sorted = sorted(list_b)

    i, j = 0, 0
    matches = 0

    while i < len(a_sorted) and j < len(b_sorted):
        diff = abs(a_sorted[i] - b_sorted[j])

        if diff <= threshold:
            matches += 1
            i += 1
            j += 1

        elif a_sorted[i] < b_sorted[j]:
            i += 1

        else:
            j += 1

    avg_len = (len(list_a) + len(list_b)) / 2

    return matches / avg_len if avg_len > 0 else 0.0
14341
+
14342
+ ###################################################################################
14343
+
14344
def escore_notes_to_chords(escore_notes,
                           use_full_chords=False,
                           repair_bad_chords=True,
                           skip_pitches=False,
                           convert_pitches=True,
                           shift_chords=False,
                           return_tones_chords=False
                           ):
    """Turn a score into one chord descriptor per chordified time step.

    The score is passed through ``solo_piano_escore_notes`` and
    ``chordify_score``; for each resulting group the unique pitches (index 4
    of a note) are reduced to a sorted list of pitch classes.

    Args:
        escore_notes: list of notes.
        use_full_chords: look chords up in ``ALL_CHORDS_FULL`` instead of
            ``ALL_CHORDS_SORTED``.
        repair_bad_chords: run pitch-class sets missing from the table
            through ``check_and_fix_tones_chord``.
        skip_pitches: emit chord indices only, and only for groups with more
            than one pitch (an unknown chord is emitted as ``-1``).
        convert_pitches: selects the output encoding, see below.
        shift_chords: with ``convert_pitches``, emit ``index + 12`` for
            chords and the pitch class for single pitches.
        return_tones_chords: emit pitch-class lists instead of indices; with
            ``convert_pitches`` false a single pitch is emitted as
            ``[-pitch]``.

    Returns:
        List of descriptors. In the default index mode groups whose
        pitch-class set is not in the table are left out; without
        ``convert_pitches`` chords are ``index + 128`` and single pitches
        are the raw pitch.
    """

    CHORDS = ALL_CHORDS_FULL if use_full_chords else ALL_CHORDS_SORTED

    sp_score = solo_piano_escore_notes(escore_notes)

    cscore = chordify_score([1000, sp_score])

    chords = []

    for c in cscore:
        pitches = sorted(set([e[4] for e in c]))
        tones_chord = sorted(set([p % 12 for p in pitches]))
        is_chord = len(pitches) > 1

        if repair_bad_chords and tones_chord not in CHORDS:
            tones_chord = check_and_fix_tones_chord(tones_chord,
                                                    use_full_chords=use_full_chords
                                                    )

        # --- pitch-class list output -------------------------------------
        if return_tones_chords:
            if convert_pitches or is_chord:
                chords.append(tones_chord)

            else:
                chords.append([-pitches[0]])

            continue

        # --- chord index output ------------------------------------------
        cho_tok = CHORDS.index(tones_chord) if tones_chord in CHORDS else -1

        if skip_pitches:
            if is_chord:
                chords.append(cho_tok)

            continue

        if cho_tok == -1:
            continue

        if convert_pitches:
            if not shift_chords:
                chords.append(cho_tok)

            elif is_chord:
                chords.append(cho_tok+12)

            else:
                chords.append(pitches[0] % 12)

        elif is_chord:
            chords.append(cho_tok+128)

        else:
            chords.append(pitches[0])

    return chords
14425
+
14426
+ ###################################################################################
14427
+
14428
def replace_chords_in_escore_notes(escore_notes,
                                   chords_list=[-1],
                                   use_full_chords=False,
                                   use_shifted_chords=False
                                   ):
    """Re-pitch the chords of a score from a cyclic list of chord indices.

    Notes on channel 9 (index 3 of a note) are set aside and merged back
    unchanged. The other notes are chordified; each group receives the next
    entry of ``chords_list`` (cycled) and its notes are moved, within their
    own octave, onto the pitch classes of that chord. When the chord has more
    tones than the group has notes, the notes are duplicated to make room.
    Within a group a (pitch, patch) pair is emitted only once.

    Args:
        escore_notes: list of notes (index 4 pitch, index 6 patch are read).
        chords_list: chord indices to apply. With ``use_shifted_chords``
            values 0-11 mean a single pitch class and larger values mean
            chord ``value - 12``. The list is not modified.
        use_full_chords: use ``ALL_CHORDS_FULL`` instead of
            ``ALL_CHORDS_SORTED``.
        use_shifted_chords: enable the shifted index scheme described above.

    Returns:
        New score sorted by start time, or ``[]`` when ``chords_list`` is
        empty or holds an index outside the valid range.
    """

    if use_full_chords:
        CHORDS = ALL_CHORDS_FULL

    else:
        CHORDS = ALL_CHORDS_SORTED

    shift = 12 if use_shifted_chords else 0

    # Valid indices are 0 .. len(CHORDS)+shift-1. The previous "<=" test let
    # the first out-of-range index through (IndexError further down), and an
    # empty list crashed in min().
    if not chords_list:
        return []

    if min(chords_list) < 0 or max(chords_list) >= len(CHORDS)+shift:
        return []

    chords_list_iter = cycle(chords_list)

    nd_score = [e for e in escore_notes if e[3] != 9]
    d_score = [e for e in escore_notes if e[3] == 9]

    cscore = chordify_score([1000, nd_score])

    new_score = []

    for c in cscore:

        cur_chord = next(chords_list_iter)

        cc = copy.deepcopy(c)

        if use_shifted_chords:
            if cur_chord < 12:
                sub_tones_chord = [cur_chord]

            else:
                sub_tones_chord = CHORDS[cur_chord-12]

        else:
            sub_tones_chord = CHORDS[cur_chord]

        stcho = cycle(sub_tones_chord)

        if len(sub_tones_chord) > len(c):
            # Not enough notes for all chord tones: replicate every note.
            cc = [copy.deepcopy(e) for e in cc for _ in range(len(sub_tones_chord))]

        pseen = []

        for e in cc:
            st = next(stcho)
            new_pitch = ((e[4] // 12) * 12) + st

            if [new_pitch, e[6]] not in pseen:
                e[4] = new_pitch

                new_score.append(e)
                pseen.append([new_pitch, e[6]])

    final_score = sorted(new_score+d_score, key=lambda x: x[1])

    return final_score
14495
+
14496
+ ###################################################################################
14497
+
14498
class Cell:
    """One entry of the alignment table used by ``align_integer_lists``."""

    def __init__(self, cost, segments, gaps, prev_dir):
        self.cost = cost          # accumulated alignment cost
        self.segments = segments  # number of separate gap runs so far
        self.gaps = gaps          # number of gap positions so far
        self.prev_dir = prev_dir  # 'diag', 'up', 'left' or '' for the origin

def align_integer_lists(seq1, seq2):
    """Globally align two integer lists with a dynamic-programming table.

    Pairing ``a`` with ``b`` costs ``abs(a - b)``; leaving a value unpaired
    (a gap on the other side) costs ``abs(value)``. Among equally cheap
    alignments the one with fewer gap runs, then fewer gaps, is preferred;
    remaining ties are resolved as pair, then gap in ``seq2``, then gap in
    ``seq1``.

    Args:
        seq1: list of integers.
        seq2: list of integers.

    Returns:
        Tuple ``(aligned1, aligned2, total_cost)`` where both aligned lists
        have equal length and use ``None`` for gaps.
    """

    len1, len2 = len(seq1), len(seq2)

    if len1 == 0:
        return [None]*len2, seq2.copy(), sum(abs(v) for v in seq2)
    if len2 == 0:
        return seq1.copy(), [None]*len1, sum(abs(v) for v in seq1)

    rank = {'diag': 0, 'up': 1, 'left': 2}

    def _order(cell):
        return (cell.cost, cell.segments, cell.gaps, rank[cell.prev_dir])

    def _with_gap(prev, value, direction):
        # A gap run is extended for free; opening a new run adds a segment.
        opened = 0 if prev.prev_dir == direction else 1
        return Cell(prev.cost + abs(value),
                    prev.segments + opened,
                    prev.gaps + 1,
                    direction)

    table = [[None] * (len2+1) for _ in range(len1+1)]
    table[0][0] = Cell(cost=0, segments=0, gaps=0, prev_dir='')

    # Borders: everything consumed so far is unpaired.
    for row in range(1, len1+1):
        table[row][0] = _with_gap(table[row-1][0], seq1[row-1], 'up')

    for col in range(1, len2+1):
        table[0][col] = _with_gap(table[0][col-1], seq2[col-1], 'left')

    for row in range(1, len1+1):
        a = seq1[row-1]

        for col in range(1, len2+1):
            b = seq2[col-1]
            corner = table[row-1][col-1]

            options = (
                Cell(corner.cost + abs(a - b), corner.segments, corner.gaps, 'diag'),
                _with_gap(table[row-1][col], a, 'up'),
                _with_gap(table[row][col-1], b, 'left'),
            )

            table[row][col] = min(options, key=_order)

    # Walk back from the far corner, collecting the pairs in reverse order.
    out1 = []
    out2 = []
    row, col = len1, len2

    while row > 0 or col > 0:
        step = table[row][col].prev_dir

        if step == 'diag':
            out1.append(seq1[row-1])
            out2.append(seq2[col-1])
            row -= 1
            col -= 1

        elif step == 'up':
            out1.append(seq1[row-1])
            out2.append(None)
            row -= 1

        else:
            out1.append(None)
            out2.append(seq2[col-1])
            col -= 1

    out1.reverse()
    out2.reverse()

    return out1, out2, int(table[len1][len2].cost)
14600
+
14601
+ ###################################################################################
14602
+
14603
def most_common_delta_time(escore_notes):
    """Find the most frequent non-zero delta time of a score.

    The score is converted with ``delta_score_notes`` and index 1 of every
    resulting note is taken as its delta time; zero deltas are ignored.

    Args:
        escore_notes: list of notes accepted by ``delta_score_notes``.

    Returns:
        ``[delta_time, share]`` where ``share`` is the fraction of non-zero
        deltas equal to ``delta_time``. ``[0, 0.0]`` is returned when there
        is no non-zero delta at all (this used to raise ``IndexError``).
    """

    dscore = delta_score_notes(escore_notes)

    dtimes = [t[1] for t in dscore if t[1] != 0]

    if not dtimes:
        return [0, 0.0]

    cdtime, count = Counter(dtimes).most_common(1)[0]

    return [cdtime, count / len(dtimes)]
14612
+
14613
+ ###################################################################################
14614
+
14615
def delta_tones(escore_notes,
                ptcs_idx=4
                ):
    """Return the differences between consecutive pitch classes.

    Args:
        escore_notes: list of notes.
        ptcs_idx: index of the pitch inside a note.

    Returns:
        List with ``len(escore_notes) - 1`` entries (empty for fewer than two
        notes); each entry is ``next_pitch % 12 - current_pitch % 12``.
    """

    tones = [note[ptcs_idx] % 12 for note in escore_notes]

    deltas = []

    for pos in range(1, len(tones)):
        deltas.append(tones[pos] - tones[pos - 1])

    return deltas
14623
+
14624
+ ###################################################################################
14625
+
14626
def find_divisors(val,
                  reverse=False
                  ):
    """List all positive divisors of ``val``.

    Args:
        val: integer; its absolute value is used.
        reverse: return the divisors in descending order.

    Returns:
        Sorted list of divisors; ``[]`` when ``val`` is zero.
    """

    if val == 0:
        return []

    magnitude = abs(val)

    # Trial division up to the square root; each hit d also yields n // d.
    small = [d for d in range(1, int(magnitude**0.5) + 1) if magnitude % d == 0]

    found = set(small)
    found.update(magnitude // d for d in small)

    return sorted(found, reverse=reverse)
14642
+
14643
+ ###################################################################################
14644
+
14645
def find_common_divisors(values,
                         reverse=False
                         ):
    """List the positive divisors shared by all non-zero ``values``.

    Zeros are ignored; the common divisors are the divisors of the greatest
    common divisor of the remaining absolute values.

    Args:
        values: iterable of integers.
        reverse: return the divisors in descending order.

    Returns:
        Sorted list of common divisors; ``[]`` when ``values`` is empty or
        contains only zeros.
    """

    if not values:
        return []

    magnitudes = [abs(v) for v in values if v != 0]

    if not magnitudes:
        return []

    common = reduce(gcd, magnitudes)

    small = [d for d in range(1, int(common**0.5) + 1) if common % d == 0]

    found = set(small)
    found.update(common // d for d in small)

    return sorted(found, reverse=reverse)
14666
+
14667
+ ###################################################################################
14668
+
14669
def strings_dict(list_of_strings,
                 verbose=False
                 ):
    """Build character <-> index lookup tables for a collection of strings.

    Args:
        list_of_strings: iterable of strings.
        verbose: show a ``tqdm`` progress bar while scanning.

    Returns:
        Tuple ``(char_to_index, index_to_char)``. Indices follow the sorted
        order of the distinct characters found.
    """

    alphabet = set()

    for text in tqdm.tqdm(list_of_strings, disable=not verbose):
        alphabet.update(text)

    ordered = sorted(alphabet)

    forward = {ch: pos for pos, ch in enumerate(ordered)}
    backward = {pos: ch for pos, ch in enumerate(ordered)}

    return forward, backward
14685
+
14686
+ ###################################################################################
14687
+
14688
def chords_common_tones_chain(chords,
                              use_full_chords=False
                              ):
    """For every chord, find the lowest tone it shares with its neighbours.

    Chord indices outside the chord table are dropped first. The first and
    last chord are compared with their single neighbour, every other chord
    with both neighbours.

    Args:
        chords: iterable of chord indices.
        use_full_chords: use ``ALL_CHORDS_FULL`` instead of
            ``ALL_CHORDS_SORTED``.

    Returns:
        One entry per valid chord: the smallest common tone, or ``-1`` when
        there is none. ``[]`` when no index is valid. NOTE: when exactly one
        index is valid, the list holding that chord's tones is returned
        instead of a list of integers.
    """

    CHORDS = ALL_CHORDS_FULL if use_full_chords else ALL_CHORDS_SORTED

    tones_chords = [CHORDS[c] for c in chords if 0 <= c < len(CHORDS)]

    count = len(tones_chords)

    if not tones_chords:
        return []

    if count < 2:
        return tones_chords

    chain = []

    for pos in range(count):
        # Neighbourhood clipped to the ends of the list: two chords at the
        # borders, three everywhere else.
        first = max(pos - 1, 0)
        last = min(pos + 1, count - 1)

        shared = set(tones_chords[first])

        for other in tones_chords[first + 1:last + 1]:
            shared &= set(other)

        chain.append(min(shared) if shared else -1)

    return chain
14723
+
14724
+ ###################################################################################
14725
+
14726
def tones_chord_to_int(tones_chord,
                       reverse_bits=True
                       ):
    """Pack a tones chord into a single integer.

    The chord is converted with ``tones_chord_to_bits`` (``reverse_bits`` is
    forwarded as its ``reverse`` argument) and the result is passed to
    ``bits_to_int``.

    Args:
        tones_chord: chord accepted by ``tones_chord_to_bits``.
        reverse_bits: forwarded bit-order flag.

    Returns:
        The value returned by ``bits_to_int``.
    """

    return bits_to_int(tones_chord_to_bits(tones_chord,
                                           reverse=reverse_bits
                                           ))
14737
+
14738
+ ###################################################################################
14739
+
14740
def int_to_tones_chord(integer,
                       reverse_bits=True
                       ):
    """Unpack an integer into a tones chord.

    The integer is reduced modulo 4096, expanded with ``int_to_bits``,
    optionally reversed in place, and converted with ``bits_to_tones_chord``.

    Args:
        integer: packed chord value.
        reverse_bits: reverse the bit list before conversion.

    Returns:
        The value returned by ``bits_to_tones_chord``.
    """

    bits = int_to_bits(integer % 4096)

    if reverse_bits:
        bits.reverse()

    return bits_to_tones_chord(bits)
14754
+
14755
+ ###################################################################################
14756
+
14757
def fix_bad_chords_in_escore_notes(escore_notes,
                                   use_full_chords=False,
                                   return_bad_chords_count=False
                                   ):
    """Drop notes so that every per-channel chord is a known tones chord.

    The score is chordified; inside each group with more than one note the
    notes are grouped by channel (index 3). For every channel except 9 the
    pitch classes are checked against the chord table. Unknown sets are
    replaced through ``check_and_fix_tones_chord`` and notes whose pitch
    class is not part of the resulting chord are removed. Channel 9 notes
    and single-note groups pass through untouched.

    Args:
        escore_notes: list of notes (index 1 time, 3 channel, 4 pitch).
        use_full_chords: use ``ALL_CHORDS_FULL`` instead of
            ``ALL_CHORDS_SORTED``.
        return_bad_chords_count: also return how many chords needed fixing.

    Returns:
        The fixed score sorted by time and descending pitch, or a tuple
        ``(score, bad_chords_count)``. Empty input is returned as is.
    """

    CHORDS = ALL_CHORDS_FULL if use_full_chords else ALL_CHORDS_SORTED

    bcount = 0

    if not escore_notes:
        if return_bad_chords_count:
            return escore_notes, bcount

        return escore_notes

    fixed_chords = []

    for c in chordify_score([1000, escore_notes]):
        # groupby needs the notes of one channel to be adjacent.
        c.sort(key=lambda x: x[3])

        if len(c) <= 1:
            fixed_chords.extend(c)
            continue

        for cha, gr in groupby(c, key=lambda x: x[3]):

            if cha == 9:
                fixed_chords.extend(gr)
                continue

            gr = list(gr)

            tones_chord = sorted(set([p[4] % 12 for p in gr]))

            if tones_chord not in CHORDS:
                tones_chord = check_and_fix_tones_chord(tones_chord,
                                                        use_full_chords=use_full_chords
                                                        )

                bcount += 1

            fixed_chords.extend([n for n in gr if n[4] % 12 in tones_chord])

    fixed_chords.sort(key=lambda x: (x[1], -x[4]))

    if return_bad_chords_count:
        return fixed_chords, bcount

    return fixed_chords
14827
+
14828
+ ###################################################################################
14829
+
14830
def remove_events_from_escore_notes(escore_notes,
                                    ele_idx=2,
                                    ele_vals=[1],
                                    chan_idx=3,
                                    skip_drums=True
                                    ):
    """Filter out events whose field ``ele_idx`` holds an unwanted value.

    Args:
        escore_notes: list of events.
        ele_idx: index of the field to test.
        ele_vals: values that cause an event to be removed (not modified).
        chan_idx: index of the channel field.
        skip_drums: never remove events on channel 9.

    Returns:
        New list with the surviving events (the events themselves are not
        copied).
    """

    def _keep(event):
        if event[ele_idx] not in ele_vals:
            return True

        # Unwanted value: only a protected drum event survives.
        return skip_drums and event[chan_idx] == 9

    return [event for event in escore_notes if _keep(event)]
14849
+
14850
+ ###################################################################################
14851
+
14852
def flatten_spikes(arr):
    """Clamp single-sample spikes to the range of their two neighbours.

    An inner value is changed when it is a strict peak while the left
    neighbour is not above the right one, or a strict dip while the left
    neighbour is not below the right one. Neighbours are always read from the
    original ``arr``, never from already flattened values.

    Args:
        arr: list of numbers.

    Returns:
        A new list; a plain copy when ``arr`` has fewer than three items.
    """

    if len(arr) < 3:
        return arr[:]

    smoothed = arr[:]

    for pos in range(1, len(arr) - 1):
        left, mid, right = arr[pos - 1], arr[pos], arr[pos + 1]

        is_peak = left <= right and mid > left and mid > right
        is_dip = left >= right and mid < left and mid < right

        if is_peak or is_dip:
            low, high = (left, right) if left <= right else (right, left)
            smoothed[pos] = max(low, min(high, mid))

    return smoothed
14867
+
14868
+ ###################################################################################
14869
+
14870
def flatten_spikes_advanced(arr, window=1):
    """Replace spikes with the median of the surrounding window.

    A value is a spike when it lies strictly above, or strictly below, both
    the mean of the ``window`` values before it and the mean of the
    ``window`` values after it. Spikes are replaced by the upper median of
    those ``2 * window`` neighbours, converted to ``int``. Neighbours are
    read from the original ``arr``.

    Args:
        arr: list of numbers.
        window: number of neighbours considered on each side.

    Returns:
        A new list; a plain copy when ``arr`` has fewer than three items.
    """

    if len(arr) < 3:
        return arr[:]

    smoothed = arr[:]
    size = len(arr)

    for pos in range(window, size - window):
        before = arr[pos - window:pos]
        after = arr[pos + 1:pos + 1 + window]

        if not before or not after:
            continue

        mean_before = sum(before) / len(before)
        mean_after = sum(after) / len(after)
        value = arr[pos]

        above = value > mean_before and value > mean_after
        below = value < mean_before and value < mean_after

        if above or below:
            around = sorted(before + after)
            smoothed[pos] = int(around[len(around) // 2])

    return smoothed
14902
+
14903
+ ###################################################################################
14904
+
14905
def add_smooth_melody_to_enhanced_score_notes(escore_notes,
                                              melody_channel=3,
                                              melody_patch=40,
                                              melody_start_chord=0,
                                              min_notes_gap=0,
                                              exclude_durs=[1],
                                              adv_flattening=True,
                                              extend_durs=True,
                                              max_mel_vels=127,
                                              max_acc_vels=80,
                                              return_melody=False
                                              ):
    """Generate a smoothed melody line and combine it with the cleaned score.

    Pipeline: duplicate pitches are removed, durations fixed, bad chords
    repaired and notes whose duration is in ``exclude_durs`` dropped; that
    accompaniment is fed to ``add_expressive_melody_to_enhanced_score_notes``
    to obtain a melody, whose pitch spikes are flattened and whose notes are
    merged and given repaired durations. Finally the velocities of melody and
    accompaniment are adjusted with ``adjust_score_velocities``.

    Args:
        escore_notes: source score.
        melody_channel, melody_patch, melody_start_chord: forwarded to the
            melody generator.
        min_notes_gap: forwarded to ``fix_escore_notes_durations``.
        exclude_durs: duration values removed from accompaniment and melody
            (not modified).
        adv_flattening: use ``flatten_spikes_advanced`` instead of
            ``flatten_spikes``.
        extend_durs: forwarded to ``fix_monophonic_score_durations``.
        max_mel_vels: second argument of ``adjust_score_velocities`` for the
            melody.
        max_acc_vels: same, for the accompaniment.
        return_melody: return only the melody.

    Returns:
        Melody, or melody plus accompaniment, sorted by time and descending
        pitch.
    """

    # --- accompaniment clean-up ------------------------------------------
    deduped = remove_duplicate_pitches_from_escore_notes(escore_notes)

    regapped = fix_escore_notes_durations(deduped,
                                          min_notes_gap=min_notes_gap
                                          )

    repaired = fix_bad_chords_in_escore_notes(regapped)

    accompaniment = remove_events_from_escore_notes(repaired,
                                                    ele_vals=exclude_durs
                                                    )

    # --- raw melody -------------------------------------------------------
    raw_melody = add_expressive_melody_to_enhanced_score_notes(accompaniment,
                                                               melody_channel=melody_channel,
                                                               melody_patch=melody_patch,
                                                               melody_start_chord=melody_start_chord,
                                                               return_melody=True,
                                                               )

    mel_score = remove_events_from_escore_notes(raw_melody,
                                                ele_vals=exclude_durs
                                                )

    # --- pitch smoothing --------------------------------------------------
    pitches = [p[4] for p in mel_score]

    smoother = flatten_spikes_advanced if adv_flattening else flatten_spikes
    smooth_pitches = smoother(pitches)

    melody = copy.deepcopy(mel_score)

    for note, pitch in zip(melody, smooth_pitches):
        note[4] = pitch

    melody = fix_monophonic_score_durations(merge_melody_notes(melody),
                                            extend_durs=extend_durs
                                            )

    # Return values are ignored here; presumably both calls adjust the
    # scores in place - TODO confirm against adjust_score_velocities.
    adjust_score_velocities(melody, max_mel_vels)
    adjust_score_velocities(accompaniment, max_acc_vels)

    combined = melody if return_melody else melody + accompaniment

    return sorted(combined, key=lambda x: (x[1], -x[4]))
14966
+
14967
+ ###################################################################################
14968
+
14969
def sorted_chords_to_full_chords(chords):
    """Translate ``ALL_CHORDS_SORTED`` indices into ``ALL_CHORDS_FULL`` indices.

    A chord that is missing from ``ALL_CHORDS_FULL`` is first passed through
    ``check_and_fix_tones_chord`` (called with its default settings).

    Args:
        chords: iterable of indices into ``ALL_CHORDS_SORTED``.

    Returns:
        List of indices into ``ALL_CHORDS_FULL``, one per input chord.
    """

    def _convert(chord_idx):
        tones = ALL_CHORDS_SORTED[chord_idx]

        if tones not in ALL_CHORDS_FULL:
            tones = check_and_fix_tones_chord(tones)

        return ALL_CHORDS_FULL.index(tones)

    return [_convert(c) for c in chords]
14982
+
14983
+ ###################################################################################
14984
+
14985
def full_chords_to_sorted_chords(chords):
    """Translate ``ALL_CHORDS_FULL`` indices into ``ALL_CHORDS_SORTED`` indices.

    A chord that is missing from ``ALL_CHORDS_SORTED`` is first passed
    through ``check_and_fix_tones_chord`` with ``use_full_chords=False``.

    Args:
        chords: iterable of indices into ``ALL_CHORDS_FULL``.

    Returns:
        List of indices into ``ALL_CHORDS_SORTED``, one per input chord.
    """

    def _convert(chord_idx):
        tones = ALL_CHORDS_FULL[chord_idx]

        if tones not in ALL_CHORDS_SORTED:
            tones = check_and_fix_tones_chord(tones, use_full_chords=False)

        return ALL_CHORDS_SORTED.index(tones)

    return [_convert(c) for c in chords]
14998
+
14999
+ ###################################################################################
15000
+
15001
def chords_to_escore_notes(chords,
                           use_full_chords=False,
                           chords_dtime=500,
                           add_melody=True,
                           add_texture=True,
                           ):
    """Render a chord index sequence as a simple list of note events.

    Each chord starts ``chords_dtime`` after the previous one and lasts
    ``chords_dtime``. Every chord tone is written one octave region up from
    48; with ``add_texture`` a randomly chosen subset of tones is doubled
    twelve semitones higher. With ``add_melody`` the first tone of the chord
    is also written at ``tone + 72`` with the fourth field set to 3 and the
    last field set to 40.

    Args:
        chords: iterable of indices into the chord table.
        use_full_chords: use ``ALL_CHORDS_FULL`` instead of
            ``ALL_CHORDS_SORTED``.
        chords_dtime: start-time step and duration of each chord.
        add_melody: add the melody note described above.
        add_texture: allow the random doubling of chord tones.

    Returns:
        List of ``['note', time, duration, channel, pitch, velocity, patch]``
        records (field names presumed from the usual score layout).
    """

    CHORDS = ALL_CHORDS_FULL if use_full_chords else ALL_CHORDS_SORTED

    score = []

    for step, chord_idx in enumerate(chords):
        start = step * chords_dtime
        tones = CHORDS[chord_idx]

        if add_melody:
            score.append(['note', start, chords_dtime, 3, tones[0]+72, 115+tones[0], 40])

        for tone in tones:
            score.append(['note', start, chords_dtime, 0, tone+48, 30+tone+48, 0])

            # The coin is flipped for every tone, even with add_texture off,
            # so the random stream is consumed exactly as before.
            coin = random.randint(0, 1)

            if coin and add_texture:
                score.append(['note', start, chords_dtime, 0, tone+60, 20+tone+60, 0])

    return score
15034
+
15035
+ ###################################################################################
15036
+
15037
  print('Module loaded!')
15038
  print('=' * 70)
15039
  print('Enjoy! :)')
midi_to_colab_audio.py CHANGED
@@ -5,14 +5,14 @@ r'''#===========================================================================
5
  # Converts any MIDI file to raw audio which is compatible
6
  # with Google Colab or Hugging Face Gradio
7
  #
8
- # Version 1.0
9
  #
10
- # Includes full source code of MIDI, pyfluidsynth, and midi_synthesizer Python modules
11
  #
12
- # Original source code for all modules was retrieved on 10/23/2023
13
  #
14
  # Project Los Angeles
15
- # Tegridy Code 2023
16
  #
17
  #===================================================================================================================
18
  #
@@ -1773,7 +1773,7 @@ def _encode(events_lol, unknown_callback=None, never_add_eot=False,
1773
 
1774
  Python bindings for FluidSynth
1775
 
1776
- Copyright 2008, Nathan Whitehead <nwhitehe@gmail.com>
1777
 
1778
 
1779
  Released under the LGPL
@@ -1790,27 +1790,67 @@ def _encode(events_lol, unknown_callback=None, never_add_eot=False,
1790
  ================================================================================
1791
  """
1792
 
1793
- from ctypes import *
1794
- from ctypes.util import find_library
1795
  import os
1796
-
1797
- # A short circuited or expression to find the FluidSynth library
1798
- # (mostly needed for Windows distributions of libfluidsynth supplied with QSynth)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1799
 
1800
  # DLL search method changed in Python 3.8
1801
  # https://docs.python.org/3/library/os.html#os.add_dll_directory
1802
- if hasattr(os, 'add_dll_directory'):
1803
  os.add_dll_directory(os.getcwd())
 
 
 
1804
 
1805
- lib = find_library('fluidsynth') or \
1806
- find_library('libfluidsynth') or \
1807
- find_library('libfluidsynth-3') or \
1808
- find_library('libfluidsynth-2') or \
1809
- find_library('libfluidsynth-1')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1810
 
1811
- if lib is None:
1812
  raise ImportError("Couldn't find the FluidSynth library.")
1813
 
 
 
1814
  # Dynamically link the FluidSynth library
1815
  # Architecture (32-/64-bit) must match your Python version
1816
  _fl = CDLL(lib)
@@ -1829,7 +1869,7 @@ def cfunc(name, result, *args):
1829
  return None
1830
 
1831
  # Bump this up when changing the interface for users
1832
- api_version = '1.3.1'
1833
 
1834
  # Function prototypes for C versions of functions
1835
 
@@ -1843,10 +1883,7 @@ fluid_version = cfunc('fluid_version', c_void_p,
1843
 
1844
  majver = c_int()
1845
  fluid_version(majver, c_int(), c_int())
1846
- if majver.value > 1:
1847
- FLUIDSETTING_EXISTS = FLUID_OK
1848
- else:
1849
- FLUIDSETTING_EXISTS = 1
1850
 
1851
  # fluid settings
1852
  new_fluid_settings = cfunc('new_fluid_settings', c_void_p)
@@ -2086,9 +2123,18 @@ fluid_synth_set_chorus_level = cfunc('fluid_synth_set_chorus_level', c_int,
2086
  ('synth', c_void_p, 1),
2087
  ('level', c_double, 1))
2088
 
 
 
 
 
 
 
 
 
2089
  fluid_synth_set_chorus_type = cfunc('fluid_synth_set_chorus_type', c_int,
2090
  ('synth', c_void_p, 1),
2091
  ('type', c_int, 1))
 
2092
  fluid_synth_get_reverb_roomsize = cfunc('fluid_synth_get_reverb_roomsize', c_double,
2093
  ('synth', c_void_p, 1))
2094
 
@@ -2220,6 +2266,77 @@ fluid_midi_event_get_value = cfunc('fluid_midi_event_get_value', c_int,
2220
  fluid_midi_event_get_velocity = cfunc('fluid_midi_event_get_velocity', c_int,
2221
  ('evt', c_void_p, 1))
2222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2223
  # fluid_player_status returned by fluid_player_get_status()
2224
  FLUID_PLAYER_READY = 0
2225
  FLUID_PLAYER_PLAYING = 1
@@ -2281,6 +2398,9 @@ new_fluid_midi_driver = cfunc('new_fluid_midi_driver', c_void_p,
2281
  ('handler', CFUNCTYPE(c_int, c_void_p, c_void_p), 1),
2282
  ('event_handler_data', c_void_p, 1))
2283
 
 
 
 
2284
 
2285
  # fluid midi router rule
2286
  class fluid_midi_router_t(Structure):
@@ -2342,6 +2462,16 @@ fluid_midi_router_add_rule = cfunc('fluid_midi_router_add_rule', c_int,
2342
  ('rule', c_void_p, 1),
2343
  ('type', c_int, 1))
2344
 
 
 
 
 
 
 
 
 
 
 
2345
  # fluidsynth 2.x
2346
  new_fluid_cmd_handler=cfunc('new_fluid_cmd_handler', c_void_p,
2347
  ('synth', c_void_p, 1),
@@ -2416,6 +2546,7 @@ class Synth:
2416
  self.audio_driver = None
2417
  self.midi_driver = None
2418
  self.router = None
 
2419
  def setting(self, opt, val):
2420
  """change an arbitrary synth setting, type-smart"""
2421
  if isinstance(val, (str, bytes)):
@@ -2451,11 +2582,11 @@ class Synth:
2451
  see http://www.fluidsynth.org/api/fluidsettings.xml for allowed values and defaults by platform
2452
  """
2453
  driver = driver or self.get_setting('audio.driver')
2454
- device = device or self.get_setting('audio.%s.device' % driver)
2455
  midi_driver = midi_driver or self.get_setting('midi.driver')
2456
 
2457
  self.setting('audio.driver', driver)
2458
- self.setting('audio.%s.device' % driver, device)
2459
  self.audio_driver = new_fluid_audio_driver(self.settings, self.synth)
2460
  self.setting('midi.driver', midi_driver)
2461
  self.router = new_fluid_midi_router(self.settings, fluid_synth_handle_midi_event, self.synth)
@@ -2463,7 +2594,7 @@ class Synth:
2463
  new_fluid_cmd_handler(self.synth, self.router)
2464
  else:
2465
  fluid_synth_set_midi_router(self.synth, self.router)
2466
- if midi_router == None: ## Use fluidsynth to create a MIDI event handler
2467
  self.midi_driver = new_fluid_midi_driver(self.settings, fluid_midi_router_handle_midi_event, self.router)
2468
  self.custom_router_callback = None
2469
  else: ## Supply an external MIDI event handler
@@ -2474,6 +2605,8 @@ class Synth:
2474
  def delete(self):
2475
  if self.audio_driver:
2476
  delete_fluid_audio_driver(self.audio_driver)
 
 
2477
  delete_fluid_synth(self.synth)
2478
  delete_fluid_settings(self.settings)
2479
  def sfload(self, filename, update_midi_preset=0):
@@ -2518,8 +2651,7 @@ class Synth:
2518
  return None
2519
  return fluid_preset_get_name(preset).decode('ascii')
2520
  else:
2521
- (sfontid, banknum, presetnum, presetname) = self.channel_info(chan)
2522
- return presetname
2523
  def router_clear(self):
2524
  if self.router is not None:
2525
  fluid_midi_router_clear_rules(self.router)
@@ -2570,16 +2702,16 @@ class Synth:
2570
  if fluid_synth_set_reverb is not None:
2571
  return fluid_synth_set_reverb(self.synth, roomsize, damping, width, level)
2572
  else:
2573
- set=0
2574
  if roomsize>=0:
2575
- set+=0b0001
2576
  if damping>=0:
2577
- set+=0b0010
2578
  if width>=0:
2579
- set+=0b0100
2580
  if level>=0:
2581
- set+=0b1000
2582
- return fluid_synth_set_reverb_full(self.synth, set, roomsize, damping, width, level)
2583
  def set_chorus(self, nr=-1, level=-1.0, speed=-1.0, depth=-1.0, type=-1):
2584
  """
2585
  nr Chorus voice count (0-99, CPU time consumption proportional to this value)
@@ -2632,17 +2764,17 @@ class Synth:
2632
  if fluid_synth_set_chorus_level is not None:
2633
  return fluid_synth_set_chorus_level(self.synth, level)
2634
  else:
2635
- return self.set_chorus(leve=level)
2636
  def set_chorus_speed(self, speed):
2637
  if fluid_synth_set_chorus_speed is not None:
2638
  return fluid_synth_set_chorus_speed(self.synth, speed)
2639
  else:
2640
  return self.set_chorus(speed=speed)
2641
- def set_chorus_depth(self, depth):
2642
  if fluid_synth_set_chorus_depth is not None:
2643
- return fluid_synth_set_chorus_depth(self.synth, depth)
2644
  else:
2645
- return self.set_chorus(depth=depth)
2646
  def set_chorus_type(self, type):
2647
  if fluid_synth_set_chorus_type is not None:
2648
  return fluid_synth_set_chorus_type(self.synth, type)
@@ -2694,10 +2826,10 @@ class Synth:
2694
  A pitch bend value of 0 is no pitch change from default.
2695
  A value of -2048 is 1 semitone down.
2696
  A value of 2048 is 1 semitone up.
2697
- Maximum values are -8192 to +8192 (transposing by 4 semitones).
2698
 
2699
  """
2700
- return fluid_synth_pitch_bend(self.synth, chan, val + 8192)
2701
  def cc(self, chan, ctrl, val):
2702
  """Send control change value
2703
 
@@ -2747,8 +2879,15 @@ class Synth:
2747
 
2748
  """
2749
  return fluid_synth_write_s16_stereo(self.synth, len)
2750
- def tuning_dump(self, bank, prog, pitch):
2751
- return fluid_synth_tuning_dump(self.synth, bank, prog, name.encode(), length(name), pitch)
 
 
 
 
 
 
 
2752
 
2753
  def midi_event_get_type(self, event):
2754
  return fluid_midi_event_get_type(event)
@@ -2767,17 +2906,20 @@ class Synth:
2767
 
2768
  def play_midi_file(self, filename):
2769
  self.player = new_fluid_player(self.synth)
2770
- if self.player == None: return FLUID_FAILED
2771
- if self.custom_router_callback != None:
 
2772
  fluid_player_set_playback_callback(self.player, self.custom_router_callback, self.synth)
2773
  status = fluid_player_add(self.player, filename.encode())
2774
- if status == FLUID_FAILED: return status
 
2775
  status = fluid_player_play(self.player)
2776
  return status
2777
 
2778
  def play_midi_stop(self):
2779
  status = fluid_player_stop(self.player)
2780
- if status == FLUID_FAILED: return status
 
2781
  status = fluid_player_seek(self.player, 0)
2782
  delete_fluid_player(self.player)
2783
  return status
@@ -2785,7 +2927,151 @@ class Synth:
2785
  def player_set_tempo(self, tempo_type, tempo):
2786
  return fluid_player_set_tempo(self.player, tempo_type, tempo)
2787
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2789
 
2790
  class Sequencer:
2791
  def __init__(self, time_scale=1000, use_system_timer=True):
@@ -2802,14 +3088,14 @@ class Sequencer:
2802
  def register_fluidsynth(self, synth):
2803
  response = fluid_sequencer_register_fluidsynth(self.sequencer, synth.synth)
2804
  if response == FLUID_FAILED:
2805
- raise Error("Registering fluid synth failed")
2806
  return response
2807
 
2808
  def register_client(self, name, callback, data=None):
2809
  c_callback = CFUNCTYPE(None, c_uint, c_void_p, c_void_p, c_void_p)(callback)
2810
  response = fluid_sequencer_register_client(self.sequencer, name.encode(), c_callback, data)
2811
  if response == FLUID_FAILED:
2812
- raise Error("Registering client failed")
2813
 
2814
  # store in a list to prevent garbage collection
2815
  self.client_callbacks.append(c_callback)
@@ -2849,7 +3135,7 @@ class Sequencer:
2849
  def _schedule_event(self, evt, time, absolute=True):
2850
  response = fluid_sequencer_send_at(self.sequencer, evt, time, absolute)
2851
  if response == FLUID_FAILED:
2852
- raise Error("Scheduling event failed")
2853
 
2854
  def get_tick(self):
2855
  return fluid_sequencer_get_tick(self.sequencer)
@@ -2868,123 +3154,307 @@ def raw_audio_string(data):
2868
 
2869
  """
2870
  import numpy
2871
- return (data.astype(numpy.int16)).tostring()
2872
 
2873
  #===============================================================================
2874
 
2875
  import numpy as np
2876
  import wave
2877
 
2878
- def midi_opus_to_colab_audio(midi_opus,
2879
- soundfont_path='/usr/share/sounds/sf2/FluidR3_GM.sf2',
2880
- sample_rate=16000, # 44100
2881
- volume_scale=10,
2882
- trim_silence=True,
2883
- silence_threshold=0.1,
2884
- output_for_gradio=False,
2885
- write_audio_to_WAV=''
2886
- ):
2887
-
2888
- def normalize_volume(matrix, factor=10):
2889
- norm = np.linalg.norm(matrix)
2890
- matrix = matrix/norm # normalized matrix
2891
- mult_matrix = matrix * factor
2892
- final_matrix = np.clip(mult_matrix, -1.0, 1.0)
2893
- return final_matrix
2894
 
2895
- if midi_opus[1]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2896
 
2897
- ticks_per_beat = midi_opus[0]
2898
- event_list = []
2899
- for track_idx, track in enumerate(midi_opus[1:]):
2900
- abs_t = 0
2901
- for event in track:
2902
- abs_t += event[1]
2903
- event_new = [*event]
2904
- event_new[1] = abs_t
2905
- event_list.append(event_new)
2906
- event_list = sorted(event_list, key=lambda e: e[1])
2907
-
2908
- tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
2909
- ss = np.empty((0, 2), dtype=np.int16)
2910
- fl = Synth(samplerate=float(sample_rate))
2911
- sfid = fl.sfload(soundfont_path)
2912
- last_t = 0
2913
- for c in range(16):
2914
- fl.program_select(c, sfid, 128 if c == 9 else 0, 0)
2915
- for event in event_list:
2916
- name = event[0]
2917
- sample_len = int(((event[1] / ticks_per_beat) * tempo / (10 ** 6)) * sample_rate)
2918
- sample_len -= int(((last_t / ticks_per_beat) * tempo / (10 ** 6)) * sample_rate)
2919
- last_t = event[1]
2920
- if sample_len > 0:
2921
- sample = fl.get_samples(sample_len).reshape(sample_len, 2)
2922
- ss = np.concatenate([ss, sample])
2923
- if name == "set_tempo":
2924
- tempo = event[2]
2925
- elif name == "patch_change":
2926
- c, p = event[2:4]
2927
- fl.program_select(c, sfid, 128 if c == 9 else 0, p)
2928
- elif name == "control_change":
2929
- c, cc, v = event[2:5]
2930
- fl.cc(c, cc, v)
2931
- elif name == "note_on" and event[3] > 0:
2932
- c, p, v = event[2:5]
2933
- fl.noteon(c, p, v)
2934
- elif name == "note_off" or (name == "note_on" and event[3] == 0):
2935
- c, p = event[2:4]
2936
- fl.noteoff(c, p)
2937
-
2938
- fl.delete()
2939
- if ss.shape[0] > 0:
2940
- max_val = np.abs(ss).max()
2941
- if max_val != 0:
2942
- ss = (ss / max_val) * np.iinfo(np.int16).max
2943
- ss = ss.astype(np.int16)
2944
-
2945
- if trim_silence:
2946
- threshold = np.std(np.abs(ss)) * silence_threshold
2947
- exceeded_thresh = np.abs(ss) > threshold
2948
- if np.any(exceeded_thresh):
2949
- last_idx = np.where(exceeded_thresh)[0][-1]
2950
- ss = ss[:last_idx+1]
2951
-
2952
- if output_for_gradio:
2953
- return ss
2954
-
2955
- ss = ss.swapaxes(1, 0)
2956
 
2957
- raw_audio = normalize_volume(ss, volume_scale)
2958
-
2959
- if write_audio_to_WAV != '':
2960
 
2961
- r_audio = raw_audio.T
 
2962
 
2963
- r_audio = np.int16(r_audio / np.max(np.abs(r_audio)) * 32767)
 
 
 
 
2964
 
2965
- with wave.open(write_audio_to_WAV, 'w') as wf:
2966
- wf.setframerate(sample_rate)
2967
- wf.setsampwidth(2)
2968
- wf.setnchannels(r_audio.shape[1])
2969
- wf.writeframes(r_audio)
 
 
 
2970
 
2971
- return raw_audio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2972
 
2973
  else:
2974
  return None
2975
 
2976
- def midi_to_colab_audio(midi_file,
2977
- soundfont_path='/usr/share/sounds/sf2/FluidR3_GM.sf2',
2978
- sample_rate=16000, # 44100
2979
- volume_scale=10,
 
 
2980
  trim_silence=True,
2981
  silence_threshold=0.1,
 
 
 
 
 
 
 
 
 
 
 
 
2982
  output_for_gradio=False,
2983
- write_audio_to_WAV=False
2984
- ):
2985
-
2986
- '''
2987
-
2988
  Returns raw audio to pass to IPython.disaply.Audio func
2989
 
2990
  Example usage:
@@ -2992,99 +3462,176 @@ def midi_to_colab_audio(midi_file,
2992
  from IPython.display import Audio
2993
 
2994
  display(Audio(raw_audio, rate=16000, normalize=False))
2995
-
2996
- '''
2997
-
2998
- def normalize_volume(matrix, factor=10):
2999
- norm = np.linalg.norm(matrix)
3000
- matrix = matrix/norm # normalized matrix
3001
- mult_matrix = matrix * factor
3002
- final_matrix = np.clip(mult_matrix, -1.0, 1.0)
3003
- return final_matrix
3004
-
3005
- midi_opus = midi2opus(open(midi_file, 'rb').read())
3006
 
3007
- if midi_opus[1]:
 
 
 
3008
 
3009
- ticks_per_beat = midi_opus[0]
3010
- event_list = []
3011
- for track_idx, track in enumerate(midi_opus[1:]):
3012
- abs_t = 0
3013
- for event in track:
3014
- abs_t += event[1]
3015
- event_new = [*event]
3016
- event_new[1] = abs_t
3017
- event_list.append(event_new)
3018
- event_list = sorted(event_list, key=lambda e: e[1])
3019
-
3020
- tempo = int((60 / 120) * 10 ** 6) # default 120 bpm
3021
- ss = np.empty((0, 2), dtype=np.int16)
3022
- fl = Synth(samplerate=float(sample_rate))
3023
- sfid = fl.sfload(soundfont_path)
3024
- last_t = 0
3025
- for c in range(16):
3026
- fl.program_select(c, sfid, 128 if c == 9 else 0, 0)
3027
- for event in event_list:
3028
- name = event[0]
3029
- sample_len = int(((event[1] / ticks_per_beat) * tempo / (10 ** 6)) * sample_rate)
3030
- sample_len -= int(((last_t / ticks_per_beat) * tempo / (10 ** 6)) * sample_rate)
3031
- last_t = event[1]
3032
- if sample_len > 0:
3033
- sample = fl.get_samples(sample_len).reshape(sample_len, 2)
3034
- ss = np.concatenate([ss, sample])
3035
- if name == "set_tempo":
3036
- tempo = event[2]
3037
- elif name == "patch_change":
3038
- c, p = event[2:4]
3039
- fl.program_select(c, sfid, 128 if c == 9 else 0, p)
3040
- elif name == "control_change":
3041
- c, cc, v = event[2:5]
3042
- fl.cc(c, cc, v)
3043
- elif name == "note_on" and event[3] > 0:
3044
- c, p, v = event[2:5]
3045
- fl.noteon(c, p, v)
3046
- elif name == "note_off" or (name == "note_on" and event[3] == 0):
3047
- c, p = event[2:4]
3048
- fl.noteoff(c, p)
3049
-
3050
- fl.delete()
3051
- if ss.shape[0] > 0:
3052
- max_val = np.abs(ss).max()
3053
- if max_val != 0:
3054
- ss = (ss / max_val) * np.iinfo(np.int16).max
3055
- ss = ss.astype(np.int16)
3056
-
3057
- if trim_silence:
3058
- threshold = np.std(np.abs(ss)) * silence_threshold
3059
- exceeded_thresh = np.abs(ss) > threshold
3060
- if np.any(exceeded_thresh):
3061
- last_idx = np.where(exceeded_thresh)[0][-1]
3062
- ss = ss[:last_idx+1]
3063
-
3064
- if output_for_gradio:
3065
- return ss
3066
 
3067
- ss = ss.swapaxes(1, 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3068
 
3069
- raw_audio = normalize_volume(ss, volume_scale)
 
 
3070
 
3071
- if write_audio_to_WAV:
 
3072
 
3073
- filename = midi_file.split('.')[-2] + '.wav'
 
 
 
 
3074
 
3075
- r_audio = raw_audio.T
 
 
 
 
 
3076
 
3077
- r_audio = np.int16(r_audio / np.max(np.abs(r_audio)) * 32767)
 
 
3078
 
3079
- with wave.open(filename, 'w') as wf:
 
 
 
 
 
 
 
 
 
 
3080
  wf.setframerate(sample_rate)
3081
  wf.setsampwidth(2)
3082
- wf.setnchannels(r_audio.shape[1])
3083
- wf.writeframes(r_audio)
 
 
3084
 
3085
- return raw_audio
3086
-
3087
- else:
3088
- return None
3089
-
3090
  #===================================================================================================================
 
5
  # Converts any MIDI file to raw audio which is compatible
6
  # with Google Colab or HUgging Face Gradio
7
  #
8
+ # Version 2.0
9
  #
10
+ # Includes full source code of MIDI and pyfluidsynth
11
  #
12
+ # Original source code for all modules was retrieved on 07/31/2025
13
  #
14
  # Project Los Angeles
15
+ # Tegridy Code 2025
16
  #
17
  #===================================================================================================================
18
  #
 
1773
 
1774
  Python bindings for FluidSynth
1775
 
1776
+ Copyright 2008--2024, Nathan Whitehead <nwhitehe@gmail.com> and others.
1777
 
1778
 
1779
  Released under the LGPL
 
1790
  ================================================================================
1791
  """
1792
 
 
 
1793
  import os
1794
+ from ctypes import (
1795
+ CDLL,
1796
+ CFUNCTYPE,
1797
+ POINTER,
1798
+ Structure,
1799
+ byref,
1800
+ c_char,
1801
+ c_char_p,
1802
+ c_double,
1803
+ c_float,
1804
+ c_int,
1805
+ c_short,
1806
+ c_uint,
1807
+ c_void_p,
1808
+ create_string_buffer,
1809
+ )
1810
+ from ctypes.util import find_library
1811
 
1812
  # DLL search method changed in Python 3.8
1813
  # https://docs.python.org/3/library/os.html#os.add_dll_directory
1814
+ if hasattr(os, 'add_dll_directory'): # Python 3.8+ on Windows only
1815
  os.add_dll_directory(os.getcwd())
1816
+ os.add_dll_directory('C:\\tools\\fluidsynth\\bin')
1817
+ # Workaround bug in find_library, it doesn't recognize add_dll_directory
1818
+ os.environ['PATH'] += ';C:\\tools\\fluidsynth\\bin'
1819
 
1820
+ # A function to find the FluidSynth library
1821
+ # (mostly needed for Windows distributions of libfluidsynth supplied with QSynth)
1822
+ def find_libfluidsynth(debug_print: bool = False) -> str:
1823
+ r"""
1824
+ macOS X64:
1825
+ * 'fluidsynth' was found at /usr/local/opt/fluid-synth/lib/libfluidsynth.dylib.
1826
+ macOS ARM64:
1827
+ * 'fluidsynth' was found at /opt/homebrew/opt/fluid-synth/lib/libfluidsynth.dylib.
1828
+ Ubuntu X86:
1829
+ * 'fluidsynth' was found at libfluidsynth.so.3.
1830
+ Windows X86:
1831
+ * 'libfluidsynth-3' was found at C:\tools\fluidsynth\bin\libfluidsynth-3.dll. --or--
1832
+ * 'fluidsynth-3' was found as C:\tools\fluidsynth\bin\fluidsynth-3.dll. >= v2.4.5
1833
+ * https://github.com/FluidSynth/fluidsynth/issues/1543
1834
+ """
1835
+ libs = "fluidsynth fluidsynth-3 libfluidsynth libfluidsynth-3 libfluidsynth-2 libfluidsynth-1"
1836
+ for lib_name in libs.split():
1837
+ lib = find_library(lib_name)
1838
+ if lib:
1839
+ if debug_print:
1840
+ print(f"'{lib_name}' was found at {lib}.")
1841
+ return lib
1842
+
1843
+ # On macOS on Apple silicon, non-Homebrew Python distributions fail to locate
1844
+ # homebrew-installed instances of FluidSynth. This workaround addresses this.
1845
+ if homebrew_prefix := os.getenv("HOMEBREW_PREFIX"):
1846
+ lib = os.path.join(homebrew_prefix, "lib", "libfluidsynth.dylib")
1847
+ if os.path.exists(lib):
1848
+ return lib
1849
 
 
1850
  raise ImportError("Couldn't find the FluidSynth library.")
1851
 
1852
+ lib = find_libfluidsynth()
1853
+
1854
  # Dynamically link the FluidSynth library
1855
  # Architecture (32-/64-bit) must match your Python version
1856
  _fl = CDLL(lib)
 
1869
  return None
1870
 
1871
  # Bump this up when changing the interface for users
1872
+ api_version = '1.3.5'
1873
 
1874
  # Function prototypes for C versions of functions
1875
 
 
1883
 
1884
  majver = c_int()
1885
  fluid_version(majver, c_int(), c_int())
1886
+ FLUIDSETTING_EXISTS = FLUID_OK if majver.value > 1 else 1
 
 
 
1887
 
1888
  # fluid settings
1889
  new_fluid_settings = cfunc('new_fluid_settings', c_void_p)
 
2123
  ('synth', c_void_p, 1),
2124
  ('level', c_double, 1))
2125
 
2126
+ fluid_synth_set_chorus_speed = cfunc('fluid_synth_set_chorus_speed', c_int,
2127
+ ('synth', c_void_p, 1),
2128
+ ('speed', c_double, 1))
2129
+
2130
+ fluid_synth_set_chorus_depth = cfunc('fluid_synth_set_chorus_depth', c_int,
2131
+ ('synth', c_void_p, 1),
2132
+ ('depth_ms', c_double, 1))
2133
+
2134
  fluid_synth_set_chorus_type = cfunc('fluid_synth_set_chorus_type', c_int,
2135
  ('synth', c_void_p, 1),
2136
  ('type', c_int, 1))
2137
+
2138
  fluid_synth_get_reverb_roomsize = cfunc('fluid_synth_get_reverb_roomsize', c_double,
2139
  ('synth', c_void_p, 1))
2140
 
 
2266
  fluid_midi_event_get_velocity = cfunc('fluid_midi_event_get_velocity', c_int,
2267
  ('evt', c_void_p, 1))
2268
 
2269
+ # fluid modulator
2270
+ new_fluid_mod = cfunc("new_fluid_mod", c_void_p)
2271
+
2272
+ delete_fluid_mod = cfunc("delete_fluid_mod", c_void_p, ("mod", c_void_p, 1))
2273
+
2274
+ fluid_mod_clone = cfunc(
2275
+ "fluid_mod_clone", c_void_p, ("mod", c_void_p, 1), ("src", c_void_p, 1),
2276
+ )
2277
+
2278
+ fluid_mod_get_amount = cfunc("fluid_mod_get_amount", c_void_p, ("mod", c_void_p, 1))
2279
+
2280
+ fluid_mod_get_dest = cfunc("fluid_mod_get_dest", c_void_p, ("mod", c_void_p, 1))
2281
+
2282
+ fluid_mod_get_flags1 = cfunc("fluid_mod_get_flags1", c_void_p, ("mod", c_void_p, 1))
2283
+
2284
+ fluid_mod_get_flags2 = cfunc("fluid_mod_get_flags2", c_void_p, ("mod", c_void_p, 1))
2285
+
2286
+ fluid_mod_get_source1 = cfunc("fluid_mod_get_source1", c_void_p, ("mod", c_void_p, 1))
2287
+
2288
+ fluid_mod_get_source2 = cfunc("fluid_mod_get_source2", c_void_p, ("mod", c_void_p, 1))
2289
+
2290
+ fluid_mod_get_transform = cfunc(
2291
+ "fluid_mod_get_transform", c_void_p, ("mod", c_void_p, 1),
2292
+ )
2293
+
2294
+ fluid_mod_has_dest = cfunc(
2295
+ "fluid_mod_has_dest", c_void_p, ("mod", c_void_p, 1), ("gen", c_uint, 1),
2296
+ )
2297
+
2298
+ fluid_mod_has_source = cfunc(
2299
+ "fluid_mod_has_dest",
2300
+ c_void_p,
2301
+ ("mod", c_void_p, 1),
2302
+ ("cc", c_uint, 1),
2303
+ ("ctrl", c_uint, 1),
2304
+ )
2305
+
2306
+ fluid_mod_set_amount = cfunc(
2307
+ "fluid_mod_set_amount", c_void_p, ("mod", c_void_p, 1), ("amount", c_double, 1),
2308
+ )
2309
+
2310
+ fluid_mod_set_dest = cfunc(
2311
+ "fluid_mod_set_dest", c_void_p, ("mod", c_void_p, 1), ("dst", c_int, 1),
2312
+ )
2313
+
2314
+ fluid_mod_set_source1 = cfunc(
2315
+ "fluid_mod_set_source1",
2316
+ c_void_p,
2317
+ ("mod", c_void_p, 1),
2318
+ ("src", c_int, 1),
2319
+ ("flags", c_int, 1),
2320
+ )
2321
+
2322
+ fluid_mod_set_source2 = cfunc(
2323
+ "fluid_mod_set_source2",
2324
+ c_void_p,
2325
+ ("mod", c_void_p, 1),
2326
+ ("src", c_int, 1),
2327
+ ("flags", c_int, 1),
2328
+ )
2329
+
2330
+ fluid_mod_set_transform = cfunc(
2331
+ "fluid_mod_set_transform", c_void_p, ("mod", c_void_p, 1), ("type", c_int, 1),
2332
+ )
2333
+
2334
+ fluid_mod_sizeof = cfunc("fluid_mod_sizeof", c_void_p)
2335
+
2336
+ fluid_mod_test_identity = cfunc(
2337
+ "fluid_mod_test_identity", c_void_p, ("mod1", c_void_p, 1), ("mod2", c_void_p, 1),
2338
+ )
2339
+
2340
  # fluid_player_status returned by fluid_player_get_status()
2341
  FLUID_PLAYER_READY = 0
2342
  FLUID_PLAYER_PLAYING = 1
 
2398
  ('handler', CFUNCTYPE(c_int, c_void_p, c_void_p), 1),
2399
  ('event_handler_data', c_void_p, 1))
2400
 
2401
+ delete_fluid_midi_driver = cfunc('delete_fluid_midi_driver', None,
2402
+ ('driver', c_void_p, 1))
2403
+
2404
 
2405
  # fluid midi router rule
2406
  class fluid_midi_router_t(Structure):
 
2462
  ('rule', c_void_p, 1),
2463
  ('type', c_int, 1))
2464
 
2465
+ # fluid file renderer
2466
+ new_fluid_file_renderer = cfunc('new_fluid_file_renderer', c_void_p,
2467
+ ('synth', c_void_p, 1))
2468
+
2469
+ delete_fluid_file_renderer = cfunc('delete_fluid_file_renderer', None,
2470
+ ('renderer', c_void_p, 1))
2471
+
2472
+ fluid_file_renderer_process_block = cfunc('fluid_file_renderer_process_block', c_int,
2473
+ ('render', c_void_p, 1))
2474
+
2475
  # fluidsynth 2.x
2476
  new_fluid_cmd_handler=cfunc('new_fluid_cmd_handler', c_void_p,
2477
  ('synth', c_void_p, 1),
 
2546
  self.audio_driver = None
2547
  self.midi_driver = None
2548
  self.router = None
2549
+ self.custom_router_callback = None
2550
  def setting(self, opt, val):
2551
  """change an arbitrary synth setting, type-smart"""
2552
  if isinstance(val, (str, bytes)):
 
2582
  see http://www.fluidsynth.org/api/fluidsettings.xml for allowed values and defaults by platform
2583
  """
2584
  driver = driver or self.get_setting('audio.driver')
2585
+ device = device or self.get_setting(f'audio.{driver}.device')
2586
  midi_driver = midi_driver or self.get_setting('midi.driver')
2587
 
2588
  self.setting('audio.driver', driver)
2589
+ self.setting(f'audio.{driver}.device', device)
2590
  self.audio_driver = new_fluid_audio_driver(self.settings, self.synth)
2591
  self.setting('midi.driver', midi_driver)
2592
  self.router = new_fluid_midi_router(self.settings, fluid_synth_handle_midi_event, self.synth)
 
2594
  new_fluid_cmd_handler(self.synth, self.router)
2595
  else:
2596
  fluid_synth_set_midi_router(self.synth, self.router)
2597
+ if midi_router is None: ## Use fluidsynth to create a MIDI event handler
2598
  self.midi_driver = new_fluid_midi_driver(self.settings, fluid_midi_router_handle_midi_event, self.router)
2599
  self.custom_router_callback = None
2600
  else: ## Supply an external MIDI event handler
 
2605
  def delete(self):
2606
  if self.audio_driver:
2607
  delete_fluid_audio_driver(self.audio_driver)
2608
+ if self.midi_driver:
2609
+ delete_fluid_midi_driver(self.midi_driver)
2610
  delete_fluid_synth(self.synth)
2611
  delete_fluid_settings(self.settings)
2612
  def sfload(self, filename, update_midi_preset=0):
 
2651
  return None
2652
  return fluid_preset_get_name(preset).decode('ascii')
2653
  else:
2654
+ return None
 
2655
  def router_clear(self):
2656
  if self.router is not None:
2657
  fluid_midi_router_clear_rules(self.router)
 
2702
  if fluid_synth_set_reverb is not None:
2703
  return fluid_synth_set_reverb(self.synth, roomsize, damping, width, level)
2704
  else:
2705
+ flags=0
2706
  if roomsize>=0:
2707
+ flags+=0b0001
2708
  if damping>=0:
2709
+ flags+=0b0010
2710
  if width>=0:
2711
+ flags+=0b0100
2712
  if level>=0:
2713
+ flags+=0b1000
2714
+ return fluid_synth_set_reverb_full(self.synth, flags, roomsize, damping, width, level)
2715
  def set_chorus(self, nr=-1, level=-1.0, speed=-1.0, depth=-1.0, type=-1):
2716
  """
2717
  nr Chorus voice count (0-99, CPU time consumption proportional to this value)
 
2764
  if fluid_synth_set_chorus_level is not None:
2765
  return fluid_synth_set_chorus_level(self.synth, level)
2766
  else:
2767
+ return self.set_chorus(level=level)
2768
  def set_chorus_speed(self, speed):
2769
  if fluid_synth_set_chorus_speed is not None:
2770
  return fluid_synth_set_chorus_speed(self.synth, speed)
2771
  else:
2772
  return self.set_chorus(speed=speed)
2773
+ def set_chorus_depth(self, depth_ms):
2774
  if fluid_synth_set_chorus_depth is not None:
2775
+ return fluid_synth_set_chorus_depth(self.synth, depth_ms)
2776
  else:
2777
+ return self.set_chorus(depth=depth_ms)
2778
  def set_chorus_type(self, type):
2779
  if fluid_synth_set_chorus_type is not None:
2780
  return fluid_synth_set_chorus_type(self.synth, type)
 
2826
  A pitch bend value of 0 is no pitch change from default.
2827
  A value of -2048 is 1 semitone down.
2828
  A value of 2048 is 1 semitone up.
2829
+ Maximum values are -8192 to +8191 (transposing by 4 semitones).
2830
 
2831
  """
2832
+ return fluid_synth_pitch_bend(self.synth, chan, max(0, min(val + 8192, 16383)))
2833
  def cc(self, chan, ctrl, val):
2834
  """Send control change value
2835
 
 
2879
 
2880
  """
2881
  return fluid_synth_write_s16_stereo(self.synth, len)
2882
+ def tuning_dump(self, bank, prog):
2883
+ """Get tuning information for given bank and preset
2884
+
2885
+ Return value is an array of length 128 with tuning factors for each MIDI note.
2886
+ Tuning factor of 0.0 in each position is standard tuning. Measured in cents.
2887
+ """
2888
+ pitch = (c_double * 128)()
2889
+ fluid_synth_tuning_dump(self.synth, bank, prog, None, 0, pitch)
2890
+ return pitch[:]
2891
 
2892
  def midi_event_get_type(self, event):
2893
  return fluid_midi_event_get_type(event)
 
2906
 
2907
  def play_midi_file(self, filename):
2908
  self.player = new_fluid_player(self.synth)
2909
+ if self.player is None:
2910
+ return FLUID_FAILED
2911
+ if self.custom_router_callback is not None:
2912
  fluid_player_set_playback_callback(self.player, self.custom_router_callback, self.synth)
2913
  status = fluid_player_add(self.player, filename.encode())
2914
+ if status == FLUID_FAILED:
2915
+ return status
2916
  status = fluid_player_play(self.player)
2917
  return status
2918
 
2919
  def play_midi_stop(self):
2920
  status = fluid_player_stop(self.player)
2921
+ if status == FLUID_FAILED:
2922
+ return status
2923
  status = fluid_player_seek(self.player, 0)
2924
  delete_fluid_player(self.player)
2925
  return status
 
2927
  def player_set_tempo(self, tempo_type, tempo):
2928
  return fluid_player_set_tempo(self.player, tempo_type, tempo)
2929
 
2930
+ def midi2audio(self, midifile, audiofile = "output.wav"):
2931
+ """Convert a midi file to an audio file"""
2932
+ self.setting("audio.file.name", audiofile)
2933
+ player = new_fluid_player(self.synth)
2934
+ fluid_player_add(player, midifile.encode())
2935
+ fluid_player_play(player)
2936
+ renderer = new_fluid_file_renderer(self.synth)
2937
+ while fluid_player_get_status(player) == FLUID_PLAYER_PLAYING:
2938
+ if fluid_file_renderer_process_block(renderer) != FLUID_OK:
2939
+ break
2940
+ delete_fluid_file_renderer(renderer)
2941
+ delete_fluid_player(player)
2942
+
2943
+ # flag values
2944
+ FLUID_MOD_POSITIVE = 0
2945
+ FLUID_MOD_NEGATIVE = 1
2946
+ FLUID_MOD_UNIPOLAR = 0
2947
+ FLUID_MOD_BIPOLAR = 2
2948
+ FLUID_MOD_LINEAR = 0
2949
+ FLUID_MOD_CONCAVE = 4
2950
+ FLUID_MOD_CONVEX = 8
2951
+ FLUID_MOD_SWITCH = 12
2952
+ FLUID_MOD_GC = 0
2953
+ FLUID_MOD_CC = 16
2954
+ FLUID_MOD_SIN = 0x80
2955
+
2956
+ # src values
2957
+ FLUID_MOD_NONE = 0
2958
+ FLUID_MOD_VELOCITY = 2
2959
+ FLUID_MOD_KEY = 3
2960
+ FLUID_MOD_KEYPRESSURE = 10
2961
+ FLUID_MOD_CHANNELPRESSURE = 13
2962
+ FLUID_MOD_PITCHWHEEL = 14
2963
+ FLUID_MOD_PITCHWHEELSENS = 16
2964
+
2965
+ # Transforms
2966
+ FLUID_MOD_TRANSFORM_LINEAR = 0
2967
+ FLUID_MOD_TRANSFORM_ABS = 2
2968
+
2969
+ class Modulator:
2970
+ def __init__(self):
2971
+ """Create new modulator object"""
2972
+ self.mod = new_fluid_mod()
2973
+
2974
+ def clone(self, src):
2975
+ response = fluid_mod_clone(self.mod, src)
2976
+ if response == FLUID_FAILED:
2977
+ raise Exception("Modulation clone failed")
2978
+ return response
2979
 
2980
+ def get_amount(self):
2981
+ response = fluid_mod_get_amount(self.mod)
2982
+ if response == FLUID_FAILED:
2983
+ raise Exception("Modulation amount get failed")
2984
+ return response
2985
+
2986
+ def get_dest(self):
2987
+ response = fluid_mod_get_dest(self.mod)
2988
+ if response == FLUID_FAILED:
2989
+ raise Exception("Modulation destination get failed")
2990
+ return response
2991
+
2992
+ def get_flags1(self):
2993
+ response = fluid_mod_get_flags1(self.mod)
2994
+ if response == FLUID_FAILED:
2995
+ raise Exception("Modulation flags1 get failed")
2996
+ return response
2997
+
2998
+ def get_flags2(self):
2999
+ response = fluid_mod_get_flags2(self.mod)
3000
+ if response == FLUID_FAILED:
3001
+ raise Exception("Modulation flags2 get failed")
3002
+ return response
3003
+
3004
+ def get_source1(self):
3005
+ response = fluid_mod_get_source1(self.mod)
3006
+ if response == FLUID_FAILED:
3007
+ raise Exception("Modulation source1 get failed")
3008
+ return response
3009
+
3010
+ def get_source2(self):
3011
+ response = fluid_mod_get_source2(self.mod)
3012
+ if response == FLUID_FAILED:
3013
+ raise Exception("Modulation source2 get failed")
3014
+ return response
3015
+
3016
+ def get_transform(self):
3017
+ response = fluid_mod_get_transform(self.mod)
3018
+ if response == FLUID_FAILED:
3019
+ raise Exception("Modulation transform get failed")
3020
+ return response
3021
+
3022
+ def has_dest(self, gen):
3023
+ response = fluid_mod_has_dest(self.mod, gen)
3024
+ if response == FLUID_FAILED:
3025
+ raise Exception("Modulation has destination check failed")
3026
+ return response
3027
+
3028
+ def has_source(self, cc, ctrl):
3029
+ response = fluid_mod_has_source(self.mod, cc, ctrl)
3030
+ if response == FLUID_FAILED:
3031
+ raise Exception("Modulation has source check failed")
3032
+ return response
3033
+
3034
+ def set_amount(self, amount):
3035
+ response = fluid_mod_set_amount(self.mod, amount)
3036
+ if response == FLUID_FAILED:
3037
+ raise Exception("Modulation set amount failed")
3038
+ return response
3039
+
3040
+ def set_dest(self, dest):
3041
+ response = fluid_mod_set_dest(self.mod, dest)
3042
+ if response == FLUID_FAILED:
3043
+ raise Exception("Modulation set dest failed")
3044
+ return response
3045
+
3046
+ def set_source1(self, src, flags):
3047
+ response = fluid_mod_set_source1(self.mod, src, flags)
3048
+ if response == FLUID_FAILED:
3049
+ raise Exception("Modulation set source 1 failed")
3050
+ return response
3051
+
3052
+ def set_source2(self, src, flags):
3053
+ response = fluid_mod_set_source2(self.mod, src, flags)
3054
+ if response == FLUID_FAILED:
3055
+ raise Exception("Modulation set source 2 failed")
3056
+ return response
3057
+
3058
+ def set_transform(self, type):
3059
+ response = fluid_mod_set_transform(self.mod, type)
3060
+ if response == FLUID_FAILED:
3061
+ raise Exception("Modulation set transform failed")
3062
+ return response
3063
+
3064
def sizeof(self):
    """Return the value of ``fluid_mod_sizeof()``.

    Raises:
        Exception: if the call returns ``FLUID_FAILED``.
    """
    size = fluid_mod_sizeof()
    if size != FLUID_FAILED:
        return size
    raise Exception("Modulation sizeof failed")
3069
+
3070
def test_identity(self, mod2):
    """Compare this modulator with ``mod2`` using ``fluid_mod_test_identity``.

    The previous body called ``fluid_mod_sizeof(self.mod, mod2)`` (a copy of
    the ``sizeof`` wrapper), which never performed an identity test.

    ``mod2`` is forwarded unchanged to the binding.
    NOTE(review): presumably a raw modulator handle like ``self.mod`` --
    confirm against callers.

    Returns the raw result of the underlying call.

    Raises:
        Exception: if the call returns ``FLUID_FAILED``.
    """
    response = fluid_mod_test_identity(self.mod, mod2)
    if response == FLUID_FAILED:
        raise Exception("Modulation identity check failed")
    return response
3075
 
3076
  class Sequencer:
3077
  def __init__(self, time_scale=1000, use_system_timer=True):
 
3088
  def register_fluidsynth(self, synth):
3089
  response = fluid_sequencer_register_fluidsynth(self.sequencer, synth.synth)
3090
  if response == FLUID_FAILED:
3091
+ raise Exception("Registering fluid synth failed")
3092
  return response
3093
 
3094
  def register_client(self, name, callback, data=None):
3095
  c_callback = CFUNCTYPE(None, c_uint, c_void_p, c_void_p, c_void_p)(callback)
3096
  response = fluid_sequencer_register_client(self.sequencer, name.encode(), c_callback, data)
3097
  if response == FLUID_FAILED:
3098
+ raise Exception("Registering client failed")
3099
 
3100
  # store in a list to prevent garbage collection
3101
  self.client_callbacks.append(c_callback)
 
3135
  def _schedule_event(self, evt, time, absolute=True):
3136
  response = fluid_sequencer_send_at(self.sequencer, evt, time, absolute)
3137
  if response == FLUID_FAILED:
3138
+ raise Exception("Scheduling event failed")
3139
 
3140
  def get_tick(self):
3141
  return fluid_sequencer_get_tick(self.sequencer)
 
3154
 
3155
  """
3156
  import numpy
3157
+ return (data.astype(numpy.int16)).tobytes()
3158
 
3159
  #===============================================================================
3160
 
3161
  import numpy as np
3162
  import wave
3163
 
3164
+ #===============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3165
 
3166
def normalize_audio(audio: np.ndarray,
                    method: str = 'peak',
                    target_level_db: float = -1.0,
                    per_channel: bool = False,
                    eps: float = 1e-9
                    ) -> np.ndarray:

    """
    Normalize audio to a target dBFS level.

    Parameters
    ----------
    audio : np.ndarray
        Array with shape (channels, samples), or (samples,) for mono.
        Any numeric dtype is accepted; it is converted to float32 first.
    method : {'peak', 'rms'}
        - 'peak': scale so that max(|audio|) = target_level_lin
        - 'rms' : scale so that RMS(audio)   = target_level_lin
    target_level_db : float
        Desired output level, in dBFS (0 dBFS = max digital full scale).
        e.g. -1.0 dBFS means ~0.8913 linear gain.
    per_channel : bool
        If True, normalize each channel independently. Otherwise, use a
        global measure across all channels.
    eps : float
        Small constant to avoid division by zero.

    Returns
    -------
    normalized : np.ndarray
        float32 array of shape (channels, samples), clipped to [-1, 1].
        Mono (samples,) input is returned as (1, samples). A zero-size
        input is returned unscaled instead of raising.

    Raises
    ------
    ValueError
        If ``method`` is neither 'peak' nor 'rms'.
    """

    # Validate first so a bad method is reported even for empty input
    if method not in ('peak', 'rms'):
        raise ValueError(f"Unsupported method '{method}'; choose 'peak' or 'rms'.")

    # Convert target dB to linear gain
    target_lin = 10 ** (target_level_db / 20.0)

    # Ensure audio is a float array
    audio = np.asarray(audio).astype(np.float32)

    # if mono, make it (1, N)
    if audio.ndim == 1:
        audio = audio[np.newaxis, :]

    # np.max / np.mean have no defined result on zero-size arrays
    if audio.size == 0:
        return audio

    # Measure per channel (axis=1) or globally (axis=None)
    axis = 1 if per_channel else None

    if method == 'peak':
        level = np.max(np.abs(audio), axis=axis, keepdims=True)
    else:
        level = np.sqrt(np.mean(audio ** 2, axis=axis, keepdims=True))

    # keepdims=True lets the scale broadcast back over the audio shape
    scales = target_lin / np.maximum(level, eps)

    # Clip just in case of rounding (and of RMS targets that push peaks over 1)
    return np.clip(audio * scales, -1.0, 1.0)
3232
+
3233
+ #===============================================================================
3234
+
3235
def _render_opus_tracks(ticks_per_beat, tracks, soundfont_path, sample_rate,
                        enable_reverb, reverb_param_dic,
                        enable_chorus, chorus_param_dic):
    """Play opus tracks through a FluidSynth ``Synth`` and return the samples.

    Returns the rendered chunks concatenated into one array reshaped to
    (frames, 2); an empty int16 (0, 2) array when nothing was rendered.
    The synth is always deleted, even if rendering raises.
    """

    # Flatten & convert delta-times to absolute-time
    events = []
    for track in tracks:
        abs_t = 0
        for name, dt, *data in track:
            abs_t += dt
            events.append([name, abs_t, *data])
    events.sort(key=lambda e: e[1])

    chunks = []

    # Setup FluidSynth
    fl = Synth(samplerate=float(sample_rate))
    try:
        sfid = fl.sfload(soundfont_path)
        for chan in range(16):
            # channel 9 = percussion GM bank 128
            fl.program_select(chan, sfid, 128 if chan == 9 else 0, 0)

        if enable_reverb:
            # roomsize  Reverb room size value (0.0-1.0)
            # damping   Reverb damping value (0.0-1.0)
            # width     Reverb width value (0.0-100.0)
            # level     Reverb level value (0.0-1.0)
            fl.set_reverb(roomsize=reverb_param_dic['roomsize'],
                          damping=reverb_param_dic['damping'],
                          width=reverb_param_dic['width'],
                          level=reverb_param_dic['level']
                          )

        if enable_chorus:
            # nr     Chorus voice count (0-99, CPU time consumption proportional to this value)
            # level  Chorus level (0.0-10.0)
            # speed  Chorus speed in Hz (0.29-5.0)
            # depth  Chorus depth in ms (max value depends on synth sample rate,
            #        0.0-21.0 is safe for sample rate values up to 96KHz)
            # type   Chorus waveform type (0=sine, 1=triangle)
            fl.set_chorus(nr=chorus_param_dic['nr'],
                          level=chorus_param_dic['level'],
                          speed=chorus_param_dic['speed'],
                          depth=chorus_param_dic['depth'],
                          type=chorus_param_dic['type']
                          )

        # Playback vars
        tempo = int((60 / 120) * 1e6)  # default 120bpm, microseconds per beat
        last_t = 0

        for name, cur_t, *data in events:
            # render the audio that elapsed since the previous event
            delta_ticks = cur_t - last_t
            last_t = cur_t
            dt_seconds = (delta_ticks / ticks_per_beat) * (tempo / 1e6)
            sample_len = int(dt_seconds * sample_rate)
            if sample_len > 0:
                # collect chunks; one concatenate at the end avoids
                # re-copying the whole buffer on every event
                chunks.append(fl.get_samples(sample_len).reshape(-1, 2))

            if name == "note_on" and data[2] > 0:
                chan, note, vel = data
                fl.noteon(chan, note, vel)

            elif name == "note_off" or (name == "note_on" and data[2] == 0):
                chan, note = data[:2]
                fl.noteoff(chan, note)

            elif name == "patch_change":
                chan, patch = data[:2]
                bank = 128 if chan == 9 else 0
                fl.program_select(chan, sfid, bank, patch)

            elif name == "control_change":
                chan, ctrl, val = data[:3]
                fl.cc(chan, ctrl, val)

            elif name == "pitch_wheel_change":
                chan, wheel = data
                fl.pitch_bend(chan, wheel)

            elif name == "set_tempo":
                # takes effect for the time after this event
                tempo = data[0]

            # Every other event (aftertouch, sysex, song position/select,
            # tune request, text/meta events, unknown types) is not sent to
            # the synth.

    finally:
        # Cleanup synth
        fl.delete()

    if not chunks:
        return np.empty((0, 2), dtype=np.int16)

    return np.concatenate(chunks, axis=0)


def _finish_colab_audio(ss, sample_rate, volume_level_db, trim_silence,
                        silence_threshold, output_for_gradio,
                        write_audio_to_WAV, wav_name):
    """Scale, trim, normalize and optionally save rendered (frames, 2) samples.

    Returns the int16 (frames, 2) array when ``output_for_gradio`` is true,
    otherwise the ``normalize_audio`` result for the (2, frames) transpose.
    """

    if ss.size:
        # Scale in float: np.abs() on int16 overflows for -32768
        as_float = ss.astype(np.float64)
        maxv = np.abs(as_float).max()
        if maxv:
            as_float = (as_float / maxv) * np.iinfo(np.int16).max
        ss = as_float.astype(np.int16)

    # Optional trimming of trailing silence
    if trim_silence and ss.size:
        thresh = np.std(np.abs(ss)) * silence_threshold
        idx = np.where(np.abs(ss) > thresh)[0]
        if idx.size:
            ss = ss[: idx[-1] + 1]

    # For Gradio you might want raw int16 PCM
    if output_for_gradio:
        return ss

    # Swap to (channels, samples) and normalize for playback
    ss = ss.T
    if ss.size:
        raw_audio = normalize_audio(ss, target_level_db=volume_level_db)
    else:
        # nothing was rendered; skip the normalizer's reductions entirely
        raw_audio = ss.astype(np.float32)

    # Optionally write WAV to disk
    if write_audio_to_WAV:
        peak = np.max(np.abs(raw_audio)) if raw_audio.size else 0.0
        if peak > 0:
            pcm = np.int16(raw_audio.T / peak * 32767)
        else:
            # silent or empty audio: avoid dividing by zero
            pcm = np.zeros(raw_audio.T.shape, dtype=np.int16)
        with wave.open(wav_name, 'wb') as wf:
            wf.setframerate(sample_rate)
            wf.setsampwidth(2)
            wf.setnchannels(pcm.shape[1])
            wf.writeframes(pcm.tobytes())

    return raw_audio


def midi_opus_to_colab_audio(midi_opus,
                             soundfont_path='/usr/share/sounds/sf2/FluidR3_GM.sf2',
                             sample_rate=16000, # 44100
                             volume_level_db=-1,
                             trim_silence=True,
                             silence_threshold=0.1,
                             enable_reverb=False,
                             reverb_param_dic={'roomsize': 0,
                                               'damping': 0,
                                               'width': 0,
                                               'level': 0
                                               },
                             enable_chorus=False,
                             chorus_param_dic={'nr': 0,
                                               'level': 0,
                                               'speed': 0.1,
                                               'depth': 0,
                                               'type': 0},
                             output_for_gradio=False,
                             write_audio_to_WAV=False,
                             output_WAV_name=''
                             ):
    """
    Render a MIDI opus ([ticks_per_beat, track, ...]) to audio with FluidSynth.

    Returns None when the opus has no tracks or its first track is empty.
    Otherwise returns int16 samples of shape (frames, 2) when
    output_for_gradio is True, or the normalized (2, frames) audio.

    The default parameter dicts are only read, never modified.

    Raises ValueError when write_audio_to_WAV is True and output_WAV_name
    is empty: an opus has no source file name to derive the WAV name from.
    """

    # An opus without tracks, or with an empty first track, yields no audio
    if len(midi_opus) < 2 or not midi_opus[1]:
        return None

    if write_audio_to_WAV and output_WAV_name == '':
        raise ValueError('output_WAV_name must be set when write_audio_to_WAV is True')

    ticks_per_beat, *tracks = midi_opus

    ss = _render_opus_tracks(ticks_per_beat, tracks, soundfont_path, sample_rate,
                             enable_reverb, reverb_param_dic,
                             enable_chorus, chorus_param_dic)

    return _finish_colab_audio(ss, sample_rate, volume_level_db, trim_silence,
                               silence_threshold, output_for_gradio,
                               write_audio_to_WAV, output_WAV_name)

#===============================================================================

def midi_to_colab_audio(midi_file,
                        soundfont_path='/usr/share/sounds/sf2/FluidR3_GM.sf2',
                        sample_rate=16000,
                        volume_level_db=-1,
                        trim_silence=True,
                        silence_threshold=0.1,
                        enable_reverb=False,
                        reverb_param_dic={'roomsize': 0,
                                          'damping': 0,
                                          'width': 0,
                                          'level': 0
                                          },
                        enable_chorus=False,
                        chorus_param_dic={'nr': 0,
                                          'level': 0,
                                          'speed': 0.1,
                                          'depth': 0,
                                          'type': 0},
                        output_for_gradio=False,
                        write_audio_to_WAV=False,
                        output_WAV_name=''
                        ):
    """
    Returns raw audio to pass to IPython.display.Audio func

    Returns None when the MIDI file contains no tracks. When
    output_for_gradio is True the int16 (frames, 2) samples are returned
    instead of the normalized (2, frames) audio.

    When write_audio_to_WAV is True the audio is also saved to
    output_WAV_name, or to the MIDI file name with its extension replaced
    by '.wav' if output_WAV_name is empty.

    Example usage:

    from IPython.display import Audio

    display(Audio(raw_audio, rate=16000, normalize=False))
    """

    # Read and decode MIDI -> opus event list (file is closed right away)
    with open(midi_file, 'rb') as midi_fh:
        midi_data = midi_fh.read()

    ticks_per_beat, *tracks = midi2opus(midi_data)
    if not tracks:
        return None

    wav_name = output_WAV_name
    if write_audio_to_WAV and wav_name == '':
        wav_name = midi_file.rsplit('.', 1)[0] + '.wav'

    ss = _render_opus_tracks(ticks_per_beat, tracks, soundfont_path, sample_rate,
                             enable_reverb, reverb_param_dic,
                             enable_chorus, chorus_param_dic)

    return _finish_colab_audio(ss, sample_rate, volume_level_db, trim_silence,
                               silence_threshold, output_for_gradio,
                               write_audio_to_WAV, wav_name)
3636
 
 
 
 
 
 
3637
  #===================================================================================================================
x_transformer_2_3_1.py CHANGED
@@ -4,7 +4,7 @@
4
  #
5
  # Partial x-transformers code With useful modifications as a stand-alone Python module
6
  #
7
- # Version 1.0
8
  #
9
  # Original source code courtesy of lucidrains
10
  # https://github.com/lucidrains/x-transformers
@@ -13,7 +13,7 @@
13
  # Original version 2.3.1 / Commit 458bc12
14
  #
15
  # Project Los Angeles
16
- # Tegridy Code 2025
17
  #
18
  #===================================================================================================================
19
  #
@@ -4146,6 +4146,416 @@ class AutoregressiveWrapper(Module):
4146
  out, = unpack(out, ps, '* n')
4147
 
4148
  return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4149
 
4150
  def compute_accuracy(self, logits, labels):
4151
 
 
4
  #
5
  # Partial x-transformers code With useful modifications as a stand-alone Python module
6
  #
7
+ # Version 3.0
8
  #
9
  # Original source code courtesy of lucidrains
10
  # https://github.com/lucidrains/x-transformers
 
13
  # Original version 2.3.1 / Commit 458bc12
14
  #
15
  # Project Los Angeles
16
+ # Tegridy Code 2026
17
  #
18
  #===================================================================================================================
19
  #
 
4146
  out, = unpack(out, ps, '* n')
4147
 
4148
  return out
4149
+
4150
@torch.no_grad()
@eval_decorator
def generate_masked(
    self,
    prompts,
    seq_len,
    eos_token = None,
    temperature = 1.,
    prompt_lens: Tensor | None = None,
    filter_logits_fn: str | Callable = top_k,
    restrict_to_max_seq_len = True,
    amateur_model: Module | Tuple[Module] | None = None,
    filter_kwargs: dict = dict(),
    contrastive_decode_kwargs: dict | Tuple[dict] = dict(
        beta = 0.5,
        alpha = 0.1
    ),
    cache_kv = True,
    return_prime=False,
    verbose=True,
    masked_token_ids: list[int] | Tensor | None = None,
    **kwargs
):
    """
    Autoregressive generation with an optional set of banned token ids.

    masked_token_ids:
        - list of ints OR a tensor of token ids
        - negative ids and ids >= the logits' vocab dimension are ignored
        - the logits of the remaining ids are set to -1e9 on every step,
          after contrastive decoding and before filtering / sampling

    Returns the generated tokens only, or prompt + generated tokens when
    return_prime is True. Tokens after the first eos_token are replaced with
    self.pad_value.
    """

    max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device

    prompts, ps = pack([prompts], '* n')

    _, t = prompts.shape

    # handle filter logits fn given as string
    if isinstance(filter_logits_fn, str):
        assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
        filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]

    # prepare masked token ids tensor (if any)
    if masked_token_ids is not None:
        if not torch.is_tensor(masked_token_ids):
            masked_token_ids = torch.tensor(masked_token_ids, dtype=torch.long, device=device)
        else:
            masked_token_ids = masked_token_ids.to(device=device, dtype=torch.long)
        # keep unique ids; the vocab size is unknown here, so only negative
        # ids are dropped now and the upper bound is checked against logits
        masked_token_ids = torch.unique(masked_token_ids)
        masked_token_ids = masked_token_ids[masked_token_ids >= 0]

    # handle variable lengthed prompts (prefixes)
    seq_start_pos = None
    if exists(prompt_lens):
        prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
        seq_start_pos = t - prompt_lens

    # output from which sampled tokens appended to
    out = prompts

    if verbose:
        print("Generating sequence of max length:", seq_len)

    # kv caches
    cache = None

    # if doing contrastive decoding, turn off filter automatically
    if exists(amateur_model):
        # a list, because wrapped models are replaced by their inner net
        # below and tuples do not support item assignment
        amateur_model = list(cast_tuple(amateur_model))
        contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)

        assert len(amateur_model) == len(contrastive_decode_kwargs)

        amateur_caches = [None] * len(amateur_model)
        filter_logits_fn = identity

        for i, module in enumerate(amateur_model):
            if isinstance(module, AutoregressiveWrapper):
                amateur_model[i] = module.net

            module.eval()

    # sampling up to seq_len
    for sl in range(seq_len):

        if restrict_to_max_seq_len:
            max_len_exceeded = out.shape[-1] > max_seq_len

            assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'

            x = out[:, -max_seq_len:]

            if exists(cache):
                for inter in cache.attn_intermediates:
                    if inter.layer_type == 'a':
                        inter.cached_kv = [kv[..., -(max_seq_len - 1):, :] for kv in inter.cached_kv]
        else:
            # without the length restriction the whole sequence is fed in
            x = out

        logits, new_cache = self.net(
            x,
            return_intermediates = True,
            cache = cache,
            seq_start_pos = seq_start_pos,
            **kwargs
        )

        if cache_kv and self.net.can_cache_kv:
            cache = new_cache

        logits = logits[:, -1]

        # handle contrastive decoding, Li et al.
        # https://arxiv.org/abs/2210.15097
        if exists(amateur_model):
            for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
                amateur_logits, next_amateur_cache = amateur(
                    x,
                    return_intermediates = True,
                    cache = amateur_cache,
                    seq_start_pos = seq_start_pos,
                    **kwargs
                )

                amateur_logits = amateur_logits[:, -1]

                assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
                logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)

                if cache_kv and amateur.can_cache_kv:
                    amateur_caches[i] = next_amateur_cache

        # apply masked token ids (after contrastive decoding, before filtering/sampling)
        if masked_token_ids is not None and masked_token_ids.numel() > 0:
            # safety: only ids inside the logits' vocab dimension can be indexed
            vocab_size = logits.shape[-1]
            valid_masked = masked_token_ids[masked_token_ids < vocab_size]
            if valid_masked.numel() > 0:
                # logits shape: (batch, vocab); a very large negative value
                # makes these ids effectively impossible to pick
                logits[:, valid_masked] = -1e9

        # filter by top_k, top_p (nucleus), top_a, or custom
        if greedy:
            sample = logits.argmax(dim = -1, keepdim = True)
        else:
            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            probs = F.softmax(filtered_logits / temperature, dim=-1)
            sample = torch.multinomial(probs, 1)

        # concat sample
        out = torch.cat((out, sample), dim=-1)

        if verbose:
            if sl % 32 == 0:
                print(sl, '/', seq_len)

        if not exists(eos_token):
            continue

        # stop once every sequence in the batch contains an eos token
        if (out == eos_token).any(dim = -1).all():

            if verbose:
                print('Model called the end of sequence at:', sl, '/', seq_len)

            break

    if exists(eos_token):
        # recomputed from the final `out` so this also works when the loop
        # body never ran (seq_len == 0)
        is_eos_tokens = (out == eos_token)

        # mask out everything after the eos tokens
        shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
        mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
        out = out.masked_fill(mask, self.pad_value)

    if not return_prime:
        # drop the prompt, keep only the generated continuation
        out = out[:, t:]

    out, = unpack(out, ps, '* n')

    return out
4330
+
4331
+ @torch.no_grad()
4332
+ @eval_decorator
4333
+ def generate_biased(
4334
+ self,
4335
+ prompts,
4336
+ seq_len,
4337
+ eos_token = None,
4338
+ temperature = 1.,
4339
+ prompt_lens: Tensor | None = None,
4340
+ filter_logits_fn: str | Callable = top_k,
4341
+ restrict_to_max_seq_len = True,
4342
+ amateur_model: Module | Tuple[Module] | None = None,
4343
+ filter_kwargs: dict = dict(),
4344
+ contrastive_decode_kwargs: dict | Tuple[dict] = dict(
4345
+ beta = 0.5,
4346
+ alpha = 0.1
4347
+ ),
4348
+ cache_kv = True,
4349
+ return_prime=False,
4350
+ verbose=True,
4351
+ logit_bias: dict | Tensor | None = None, # <-- new parameter
4352
+ **kwargs
4353
+ ):
4354
+ """
4355
+ Autoregressive generation with optional additive logit bias.
4356
+
4357
+ logit_bias:
4358
+ - dict[token_id -> float] OR
4359
+ - torch.Tensor of shape (vocab,) OR (batch, vocab)
4360
+ """
4361
+
4362
+ max_seq_len, greedy, device = self.max_seq_len, temperature == 0., prompts.device
4363
+
4364
+ prompts, ps = pack([prompts], '* n')
4365
+
4366
+ b, t = prompts.shape
4367
+
4368
+ # handle filter logits fn given as string
4369
+ if isinstance(filter_logits_fn, str):
4370
+ assert filter_logits_fn in FILTER_LOGITS_FN, f"only {join(FILTER_LOGITS_FN.keys())} are available"
4371
+ filter_logits_fn = FILTER_LOGITS_FN[filter_logits_fn]
4372
+
4373
+ # handle variable lengthed prompts (prefixes)
4374
+ seq_start_pos = None
4375
+ if exists(prompt_lens):
4376
+ prompts = align_right(prompts, prompt_lens, pad_id = self.pad_value)
4377
+ seq_start_pos = t - prompt_lens
4378
+
4379
+ # output from which sampled tokens appended to
4380
+ out = prompts
4381
+
4382
+ if verbose:
4383
+ print("Generating sequence of max length:", seq_len)
4384
+
4385
+ # kv caches
4386
+ cache = None
4387
+
4388
+ # if doing contrastive decoding, turn off filter automatically
4389
+ if exists(amateur_model):
4390
+ amateur_model = cast_tuple(amateur_model)
4391
+ contrastive_decode_kwargs = cast_tuple(contrastive_decode_kwargs)
4392
+ assert len(amateur_model) == len(contrastive_decode_kwargs)
4393
+ amateur_caches = [None] * len(amateur_model)
4394
+ filter_logits_fn = identity
4395
+ for i, module in enumerate(amateur_model):
4396
+ if isinstance(module, AutoregressiveWrapper):
4397
+ amateur_model[i] = module.net
4398
+ module.eval()
4399
+
4400
+ # -------------------------
4401
+ # Prepare logit_bias (robust vocab-size detection)
4402
+ # -------------------------
4403
+ prepared_bias = None
4404
+ lazy_build_bias_from_dict = None
4405
+
4406
+ if exists(logit_bias):
4407
+ if isinstance(logit_bias, dict):
4408
+ # try to determine vocab size from model without using logits
4409
+ vocab_size = None
4410
+
4411
+ # common places to find vocab size
4412
+ try:
4413
+ if hasattr(self.net, "config") and getattr(self.net.config, "vocab_size", None) is not None:
4414
+ vocab_size = int(self.net.config.vocab_size)
4415
+ elif getattr(self.net, "vocab_size", None) is not None:
4416
+ vocab_size = int(self.net.vocab_size)
4417
+ else:
4418
+ # try to infer from embedding / output projection weights
4419
+ # huggingface style: get_output_embeddings() or embed_tokens or lm_head
4420
+ get_out = getattr(self.net, "get_output_embeddings", None)
4421
+ if callable(get_out) and get_out() is not None:
4422
+ vocab_size = int(get_out().weight.shape[0])
4423
+ elif hasattr(self.net, "embed_tokens"):
4424
+ vocab_size = int(self.net.embed_tokens.weight.shape[0])
4425
+ elif hasattr(self.net, "lm_head"):
4426
+ vocab_size = int(self.net.lm_head.weight.shape[0])
4427
+ except Exception:
4428
+ vocab_size = None
4429
+
4430
+ if vocab_size is not None:
4431
+ bias_vec = torch.zeros(int(vocab_size), device=device, dtype=torch.float32)
4432
+ for tok, val in logit_bias.items():
4433
+ tok_i = int(tok)
4434
+ if tok_i < 0 or tok_i >= vocab_size:
4435
+ raise IndexError(f"logit_bias token id {tok_i} out of range for vocab size {vocab_size}")
4436
+ bias_vec[tok_i] = float(val)
4437
+ prepared_bias = bias_vec
4438
+ else:
4439
+ # can't determine vocab size yet — build lazily after first logits are available
4440
+ lazy_build_bias_from_dict = {int(k): float(v) for k, v in logit_bias.items()}
4441
+
4442
+ elif isinstance(logit_bias, torch.Tensor):
4443
+ prepared_bias = logit_bias.to(device=device, dtype=torch.float32)
4444
+ else:
4445
+ raise TypeError("logit_bias must be dict or torch.Tensor")
4446
+
4447
+ # sampling up to seq_len
4448
+ for sl in range(seq_len):
4449
+
4450
+ if restrict_to_max_seq_len:
4451
+ max_len_exceeded = out.shape[-1] > max_seq_len
4452
+ assert not (cache_kv and max_len_exceeded and not self.net.can_cache_kv_outside_max_seq_len), \
4453
+ 'the network cannot use cached key values when decoding outside the max sequence length. most likely because you are using absolute positional embedding. you can switch to rotary embeddings to resolve this issue'
4454
+ x = out[:, -max_seq_len:]
4455
+ if exists(cache):
4456
+ for inter in cache.attn_intermediates:
4457
+ if inter.layer_type == 'a':
4458
+ inter.cached_kv = [t[..., -(max_seq_len - 1):, :] for t in inter.cached_kv]
4459
+ else:
4460
+ x = out
4461
+
4462
+ logits, new_cache = self.net(
4463
+ x,
4464
+ return_intermediates = True,
4465
+ cache = cache,
4466
+ seq_start_pos = seq_start_pos,
4467
+ **kwargs
4468
+ )
4469
+
4470
+ if cache_kv and self.net.can_cache_kv:
4471
+ cache = new_cache
4472
+
4473
+ logits = logits[:, -1] # shape (batch, vocab)
4474
+
4475
+ # If we couldn't build the bias earlier because vocab size was unknown,
4476
+ # build it now from the first logits tensor.
4477
+ if lazy_build_bias_from_dict is not None:
4478
+ vocab_size = logits.shape[-1]
4479
+ bias_vec = torch.zeros(vocab_size, device=device, dtype=torch.float32)
4480
+ for tok, val in lazy_build_bias_from_dict.items():
4481
+ if tok < 0 or tok >= vocab_size:
4482
+ raise IndexError(f"logit_bias token id {tok} out of range for vocab size {vocab_size}")
4483
+ bias_vec[tok] = val
4484
+ prepared_bias = bias_vec
4485
+ lazy_build_bias_from_dict = None # only build once
4486
+
4487
+ # handle contrastive decoding, Li et al.
4488
+ # https://arxiv.org/abs/2210.15097
4489
+ if exists(amateur_model):
4490
+ for i, (amateur, amateur_cache, amateur_contrastive_decode_kwargs) in enumerate(zip(amateur_model, amateur_caches, contrastive_decode_kwargs)):
4491
+ amateur_logits, next_amateur_cache = amateur(
4492
+ x,
4493
+ return_intermediates = True,
4494
+ cache = amateur_cache,
4495
+ seq_start_pos = seq_start_pos,
4496
+ **kwargs
4497
+ )
4498
+ amateur_logits = amateur_logits[:, -1]
4499
+ assert amateur_logits.shape == logits.shape, 'logits dimension are not the same between amateur and expert model'
4500
+ logits = contrastive_decode_fn(logits, amateur_logits, **amateur_contrastive_decode_kwargs)
4501
+ if cache_kv and amateur.can_cache_kv:
4502
+ amateur_caches[i] = next_amateur_cache
4503
+
4504
+ # -------------------------
4505
+ # Apply logit bias if provided
4506
+ # -------------------------
4507
+ if exists(prepared_bias):
4508
+ # prepared_bias can be (vocab,) or (batch, vocab)
4509
+ if prepared_bias.dim() == 1:
4510
+ # broadcast to batch
4511
+ logits = logits + prepared_bias.unsqueeze(0)
4512
+ elif prepared_bias.dim() == 2:
4513
+ # expect shape (batch, vocab)
4514
+ if prepared_bias.shape[0] != logits.shape[0]:
4515
+ raise ValueError("logit_bias tensor batch size must match logits batch size")
4516
+ logits = logits + prepared_bias
4517
+ else:
4518
+ raise ValueError("logit_bias tensor must be 1D (vocab,) or 2D (batch, vocab)")
4519
+
4520
+ # filter by top_k, top_p (nucleus), top_a, or custom
4521
+ if greedy:
4522
+ sample = logits.argmax(dim = -1, keepdim = True)
4523
+ else:
4524
+ filtered_logits = filter_logits_fn(logits, **filter_kwargs)
4525
+ probs = F.softmax(filtered_logits / temperature, dim=-1)
4526
+ sample = torch.multinomial(probs, 1)
4527
+
4528
+ # concat sample
4529
+ out = torch.cat((out, sample), dim=-1)
4530
+
4531
+ if verbose:
4532
+ if sl % 32 == 0:
4533
+ print(sl, '/', seq_len)
4534
+
4535
+ if not exists(eos_token):
4536
+ continue
4537
+
4538
+ is_eos_tokens = (out == eos_token)
4539
+
4540
+ if is_eos_tokens.any(dim = -1).all():
4541
+ if verbose:
4542
+ print('Model called the end of sequence at:', sl, '/', seq_len)
4543
+ break
4544
+
4545
+ if exists(eos_token):
4546
+ # mask out everything after the eos tokens
4547
+ shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
4548
+ mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
4549
+ out = out.masked_fill(mask, self.pad_value)
4550
+
4551
+ if return_prime:
4552
+ out = out[:, :]
4553
+ else:
4554
+ out = out[:, t:]
4555
+
4556
+ out, = unpack(out, ps, '* n')
4557
+
4558
+ return out
4559
 
4560
  def compute_accuracy(self, logits, labels):
4561