saumya-pailwan commited on
Commit
251510f
·
verified ·
1 Parent(s): 9ae1e69
Files changed (1) hide show
  1. app.py +42 -121
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- import spaces # Enables ZeroGPU on Hugging Face
3
  import gradio as gr
4
  import torch
5
  from dataclasses import asdict
@@ -25,10 +25,10 @@ LARGE_MODEL = "stanford-crfm/music-large-800k"
25
  model_card = ModelCard(
26
  name="Anticipatory Music Transformer",
27
  description=(
28
- "Generate musical accompaniment for your existing melody using the Anticipatory Music Transformer.\n"
29
- "Input: a MIDI file that includes a short section of accompaniment followed by a melody.\n"
30
- "Output: a new MIDI file with extended accompaniment that continues naturally with the melody.\n"
31
- "Use the control below to select how much of the input song is used as context and select the model size. \n\n"
32
  ),
33
  author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
34
  tags=["midi", "generation", "accompaniment"]
@@ -60,135 +60,59 @@ def load_amt_model(model_choice: str):
60
  return model
61
 
62
  def find_melody_program(mid, debug=False):
63
- """
64
- Automatically detect the melody track's program number from a MIDI file.
65
- Uses a balanced heuristic: pitch + note count + temporal coverage.
66
- """
67
  track_stats = []
68
- total_duration = 0
69
-
70
  for i, track in enumerate(mid.tracks):
71
  pitches, times = [], []
72
  current_time = 0
73
- current_program = None
74
- track_note_count = 0
75
-
76
  for msg in track:
77
- if msg.type not in ("note_on", "program_change"):
78
- continue
79
-
80
- current_time += msg.time
81
- if msg.type == "program_change":
82
- current_program = msg.program
83
- continue
84
-
85
- # note_on event
86
- if msg.velocity > 0:
87
  pitches.append(msg.note)
88
  times.append(current_time)
89
- track_note_count += 1
90
-
91
- # Early stop if enough notes gathered
92
- if track_note_count >= 100:
93
- break
94
-
95
- # Skip empty or trivial tracks
96
- if not pitches:
97
- continue
98
-
99
- # Compute duration for this track and update total_duration
100
- track_duration = max(times) - min(times)
101
- total_duration = max(total_duration, current_time)
102
-
103
- mean_pitch = sum(pitches) / len(pitches)
104
- polyphony = len(set(pitches)) / len(pitches)
105
- coverage = track_duration / total_duration if total_duration > 0 else 0
106
-
107
- track_stats.append((i, mean_pitch, len(pitches), current_program, polyphony, coverage))
108
 
109
  if not track_stats:
110
- return None, False
111
-
112
- if len(track_stats) == 1:
113
- prog = track_stats[0][3]
114
  if debug:
115
- if prog == 0:
116
- print("Single-track MIDI detected, program 0 (Acoustic Grand Piano) will be treated as melody.")
117
- else:
118
- print(f"Single-track MIDI detected, using program {prog or 'None'}")
119
- return prog, prog is not None
120
-
121
- candidates = [t for t in track_stats if t[3] is not None and t[3] > 0]
122
- has_valid_programs = len(candidates) > 0
123
- if not candidates:
124
- candidates = track_stats
125
-
126
- if debug:
127
- print(f"\nCandidates: {len(candidates)} tracks")
128
-
129
- max_notes = max(t[2] for t in candidates)
130
- max_pitch = max(t[1] for t in candidates)
131
- min_pitch = min(t[1] for t in candidates)
132
- pitch_span = max_pitch - min_pitch if max_pitch > min_pitch else 1
133
-
134
- best_score = -1
135
- best_program = None
136
- best_track = None
137
- best_pitch = None
138
-
139
- for t in candidates:
140
- idx, pitch, notes, prog, poly, coverage = t
141
- pitch_norm = (pitch - min_pitch) / pitch_span
142
- notes_norm = notes / max_notes
143
 
144
- score = (pitch_norm * 0.35) + (notes_norm * 0.35) + (coverage * 0.30)
145
 
146
- if poly < 0.15:
147
- score *= 0.95
148
- if 55 <= pitch <= 75:
149
- score *= 1.1
150
- if notes >= 30:
151
- score *= 1.05
152
- if coverage > 0.7:
153
- score *= 1.15
154
 
155
- if score > best_score:
156
- best_score = score
157
- best_program = prog
158
- best_track = idx
159
- best_pitch = pitch
160
 
161
- return best_program, has_valid_programs
 
 
 
 
162
 
163
 
164
  def auto_extract_melody(mid, debug=False):
165
- """
166
- Extract melody events from a MIDI object (already loaded via MidiFile).
167
- Optimized to avoid re-reading the file from disk.
168
- Returns: (all_events, melody_events)
169
- """
170
  events = midi_to_events(mid)
 
 
171
 
172
- melody_program, has_valid_program = find_melody_program(mid, debug=debug)
 
173
 
174
- if not has_valid_program or melody_program is None or melody_program == 0:
 
175
  if debug:
176
- print("No valid program changes in MIDI, using all events as melody")
177
- return events, events
178
-
179
- events, melody = extract_instruments(events, [melody_program])
180
-
181
- if len(melody) == 0:
182
  if debug:
183
- print("No events found for selected program, using all events")
184
- return events, events
185
-
186
- if debug:
187
- print(f"Extracted {len(melody)} melody events from program {melody_program}")
188
 
189
  return events, melody
190
 
191
- @spaces.GPU
192
  # Core generation
193
  def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
194
  """
@@ -199,7 +123,7 @@ def generate_accompaniment(midi_path: str, model_choice: str, history_length: fl
199
 
200
  # Parse MIDI correctly, then convert to events
201
  mid = MidiFile(midi_path)
202
- #print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")
203
 
204
  # Automatically detect and extract melody
205
  all_events, melody = auto_extract_melody(mid, debug=True)
@@ -207,16 +131,18 @@ def generate_accompaniment(midi_path: str, model_choice: str, history_length: fl
207
  print("No melody detected; using all events")
208
  melody = all_events
209
 
210
- total_time = round(ops.max_time(all_events, seconds=True))
211
-
212
  # History portion
213
  history = ops.clip(all_events, 0, history_length, clip_duration=False)
 
 
 
 
 
214
 
215
- # Generate accompaniment for the remaining duration
216
  accompaniment = generate(
217
  model,
218
- start_time=history_length, # start after history
219
- end_time=total_time, # go to end
220
  inputs=history,
221
  controls=melody,
222
  top_p=0.95,
@@ -273,16 +199,11 @@ with gr.Blocks() as demo:
273
  choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
274
  value=MEDIUM_MODEL,
275
  label="Select AMT Model (Faster vs. Higher Quality)"
276
- ).set_info(
277
- "Choose the model size: Smaller models generate faster but may be less detailed. \n larger models produce richer, more expressive accompaniment."
278
  )
279
 
280
  history_slider = gr.Slider(
281
  minimum=1, maximum=10, step=1, value=5,
282
- label="Intro Duration for Context (sec)"
283
- ).set_info(
284
- "Controls how much of the beginning of your song is used as context for generation.\n "
285
- "A longer history helps the model better understand the style and rhythm before extending the accompaniment."
286
  )
287
 
288
  # Outputs (JSON FIRST)
 
1
  import os
2
+ #import spaces # Enables ZeroGPU on Hugging Face
3
  import gradio as gr
4
  import torch
5
  from dataclasses import asdict
 
25
  model_card = ModelCard(
26
  name="Anticipatory Music Transformer",
27
  description=(
28
+ "Generate musical accompaniment for your existing vamp using the Anticipatory Music Transformer. "
29
+ "Input: a MIDI file with a short accompaniment (vamp) followed by a melody line. "
30
+ "Output: a new MIDI file with extended accompaniment matching the melody continuation. "
31
+ "Use the sliders to choose model size and how much of the song is used as context."
32
  ),
33
  author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
34
  tags=["midi", "generation", "accompaniment"]
 
60
  return model
61
 
62
  def find_melody_program(mid, debug=False):
 
 
 
 
63
  track_stats = []
 
 
64
  for i, track in enumerate(mid.tracks):
65
  pitches, times = [], []
66
  current_time = 0
 
 
 
67
  for msg in track:
68
+ current_time += getattr(msg, "time", 0)
69
+ if msg.type == "note_on" and msg.velocity > 0:
 
 
 
 
 
 
 
 
70
  pitches.append(msg.note)
71
  times.append(current_time)
72
+ if pitches:
73
+ mean_pitch = sum(pitches) / len(pitches)
74
+ span = (max(times) - min(times)) or 1
75
+ density = len(pitches) / span
76
+ polyphony = len(set(pitches)) / len(pitches)
77
+ track_stats.append((i, mean_pitch, len(pitches), density, polyphony))
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  if not track_stats:
 
 
 
 
80
  if debug:
81
+ print("No notes detected in any track.")
82
+ return 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
+ melody_idx = sorted(track_stats, key=lambda x: (-x[1], -x[3]))[0][0]
85
 
86
+ return melody_idx
 
 
 
 
 
 
 
87
 
 
 
 
 
 
88
 
89
+ def get_program_number(mid, track_index):
90
+ for msg in mid.tracks[track_index]:
91
+ if msg.type == "program_change":
92
+ return msg.program
93
+ return None
94
 
95
 
96
  def auto_extract_melody(mid, debug=False):
 
 
 
 
 
97
  events = midi_to_events(mid)
98
+ melody_track = find_melody_program(mid, debug=debug)
99
+ melody_program = get_program_number(mid, melody_track)
100
 
101
+ if debug:
102
+ print(f"Melody Track: {melody_track} | Program: {melody_program}")
103
 
104
+ if melody_program is not None:
105
+ events, melody = extract_instruments(events, [melody_program])
106
  if debug:
107
+ print(f"Extracted {len(melody)} melody events from program {melody_program}")
108
+ else:
 
 
 
 
109
  if debug:
110
+ print("No program number found; using all events as melody.")
111
+ melody = events
 
 
 
112
 
113
  return events, melody
114
 
115
+ #@spaces.GPU
116
  # Core generation
117
  def generate_accompaniment(midi_path: str, model_choice: str, history_length: float):
118
  """
 
123
 
124
  # Parse MIDI correctly, then convert to events
125
  mid = MidiFile(midi_path)
126
+ print(f"Loaded MIDI file: type {mid.type} ({'single track' if mid.type == 0 else 'multi-track'})")
127
 
128
  # Automatically detect and extract melody
129
  all_events, melody = auto_extract_melody(mid, debug=True)
 
131
  print("No melody detected; using all events")
132
  melody = all_events
133
 
 
 
134
  # History portion
135
  history = ops.clip(all_events, 0, history_length, clip_duration=False)
136
+ start_time = ops.max_time(history, seconds=True)
137
+
138
+ mid_time = mid.length or 0
139
+ ops_time = ops.max_time(all_events, seconds=True)
140
+ total_time = round(max(mid_time, ops_time))
141
 
 
142
  accompaniment = generate(
143
  model,
144
+ start_time=history_length,
145
+ end_time=total_time,
146
  inputs=history,
147
  controls=melody,
148
  top_p=0.95,
 
199
  choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
200
  value=MEDIUM_MODEL,
201
  label="Select AMT Model (Faster vs. Higher Quality)"
 
 
202
  )
203
 
204
  history_slider = gr.Slider(
205
  minimum=1, maximum=10, step=1, value=5,
206
+ label="Select History Length (seconds)"
 
 
 
207
  )
208
 
209
  # Outputs (JSON FIRST)