lllindsey0615 committed on
Commit
84d0752
·
1 Parent(s): b01bf50

enable model selection

Browse files
Files changed (1) hide show
  1. app.py +54 -18
app.py CHANGED
@@ -7,42 +7,75 @@ from anticipation.tokenize import extract_instruments
7
  import torch
8
  from pyharp import *
9
 
 
 
 
 
 
10
  # Define the model card for PyHARP
11
  model_card = ModelCard(
12
  name="Anticipatory Music Transformer",
13
  description="Using Anticipatory Music Transformer (AMT) to generate accompaniment for a given MIDI file.",
14
  author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
15
  tags=["midi", "generation", "accompaniment"],
16
- midi_in=True, # PyHARP will automatically handle MIDI input
17
  midi_out=True
18
  )
19
 
20
- # Load the AMT model
21
- model_name = "stanford-crfm/music-medium-800k"
22
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
- model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
25
 
26
- # Function to generate accompaniment
27
- def generate_accompaniment(midi_file, selected_midi_program, start_time, end_time):
28
  # Convert MIDI to events
29
  events = midi_to_events(midi_file.name)
 
 
 
 
30
  # Clip events based on the selected time range
31
  clipped_events = ops.clip(events, start_time, end_time)
 
32
  # Normalize timeline (start at 0)
33
  clipped_events = ops.translate(clipped_events, -ops.min_time(clipped_events, seconds=False))
34
- # Extract the chosen melody instrument
35
- clipped_events, melody = extract_instruments(clipped_events, [selected_midi_program])
 
 
36
  # Prepare history (first 5 seconds of the segment)
37
  history = ops.clip(clipped_events, 0, 5, clip_duration=False)
38
- # Generate accompaniment using the AMT model
 
39
  accompaniment = generate(
40
  model, 0, end_time - start_time, inputs=history, controls=melody, top_p=0.98
41
  )
 
42
  # Normalize generated accompaniment
43
  accompaniment = ops.translate(accompaniment, -ops.min_time(accompaniment, seconds=False))
 
44
  # Combine accompaniment with melody
45
  output_events = ops.clip(ops.combine(accompaniment, melody), 0, end_time - start_time, clip_duration=True)
 
46
  # Convert back to MIDI
47
  output_midi = "generated_accompaniment.midi"
48
  mid = events_to_midi(output_events)
@@ -50,17 +83,20 @@ def generate_accompaniment(midi_file, selected_midi_program, start_time, end_tim
50
 
51
  return output_midi
52
 
53
-
54
- # PyHARP process function
55
- def process_fn(input_midi, selected_midi_program, start_time, end_time):
56
- output_midi = generate_accompaniment(input_midi, selected_midi_program, start_time, end_time)
57
  return output_midi, LabelList()
58
 
59
-
60
- # Build Gradio interface wrapped in PyHARP
61
  with gr.Blocks() as demo:
62
  components = [
63
- gr.Slider(0, 127, step=1, value=53, label="Select Melody Instrument (MIDI Program Number)"),
 
 
 
 
64
  gr.Slider(0, 30, step=1, label="Start Time (seconds)"),
65
  gr.Slider(0, 30, step=1, label="End Time (seconds)"),
66
  ]
@@ -72,4 +108,4 @@ with gr.Blocks() as demo:
72
  )
73
 
74
  demo.queue()
75
- demo.launch(share=True, show_error=True)
 
7
  import torch
8
  from pyharp import *
9
 
10
+ # === Define AMT Model Checkpoints ===
11
+ SMALL_MODEL = "stanford-crfm/music-small-800k" # Fastest inference, lowest quality
12
+ MEDIUM_MODEL = "stanford-crfm/music-medium-800k" # Balanced: slower inference, better quality
13
+ LARGE_MODEL = "stanford-crfm/music-large-800k" # Slowest inference, best quality
14
+
15
  # Define the model card for PyHARP (metadata shown in the HARP host UI)
16
  model_card = ModelCard(
17
  name="Anticipatory Music Transformer",
18
  description="Using Anticipatory Music Transformer (AMT) to generate accompaniment for a given MIDI file.",
19
  author="John Thickstun, David Hall, Chris Donahue, Percy Liang",
20
  tags=["midi", "generation", "accompaniment"],
21
+ midi_in=True, # PyHARP automatically handles MIDI input
22
  midi_out=True
23
  )
24
 
25
# === Function to Load AMT Model Based on Selection ===
# Cache of already-loaded checkpoints so switching models in the UI (or
# repeated requests with the same choice) does not re-download and
# re-initialize the weights on every call.
_MODEL_CACHE = {}

def load_amt_model(model_choice):
    """Load (and memoize) the selected AMT checkpoint.

    Args:
        model_choice: Hugging Face model id (one of SMALL_MODEL,
            MEDIUM_MODEL, LARGE_MODEL).

    Returns:
        The causal-LM model moved to GPU if available, else CPU.
    """
    if model_choice not in _MODEL_CACHE:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        _MODEL_CACHE[model_choice] = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
    return _MODEL_CACHE[model_choice]
30
+
31
# === Function to Detect the Melody Program Automatically ===
def detect_melody_program(midi_file):
    """Pick the MIDI program to treat as the melody.

    Parses the uploaded file and inspects its instrument programs: the
    sole program if there is exactly one, the lowest-numbered program if
    there are several, or 0 (Acoustic Grand Piano) if none are found.

    Args:
        midi_file: uploaded file object with a ``.name`` path to a MIDI file.

    Returns:
        int: the chosen MIDI program number.
    """
    events = midi_to_events(midi_file.name)
    programs = ops.get_instruments(events).keys()
    # min() of a single element is that element, so this collapses the
    # one/many/none branches into one expression; `default` covers the
    # empty case.
    return min(programs, default=0)
43
 
44
# === Function to Generate Accompaniment ===
def generate_accompaniment(midi_file, model_choice, start_time, end_time):
    """Generate an accompaniment for the [start_time, end_time] window of a MIDI file.

    Args:
        midi_file: uploaded file object with a ``.name`` path to a MIDI file.
        model_choice: Hugging Face checkpoint id of the AMT model to use.
        start_time: segment start in seconds.
        end_time: segment end in seconds (assumed > start_time — TODO validate;
            the UI sliders default both to 0, which yields an empty segment).

    Returns:
        str: path of the MIDI file containing melody + generated accompaniment.
    """
    # Load the selected AMT model (cached across calls)
    model = load_amt_model(model_choice)

    # Convert MIDI to AMT events
    events = midi_to_events(midi_file.name)

    # Automatically detect the melody program.
    # NOTE(review): this re-parses the same file a second time inside
    # detect_melody_program — harmless, but `events` could be shared.
    melody_program = detect_melody_program(midi_file)

    # Keep only the requested time window...
    clipped_events = ops.clip(events, start_time, end_time)

    # ...and shift it so the segment starts at t=0
    clipped_events = ops.translate(clipped_events, -ops.min_time(clipped_events, seconds=False))

    # Split the melody instrument out; the remainder is accompaniment context
    clipped_events, melody = extract_instruments(clipped_events, [melody_program])

    # Seed generation with the first 5 seconds of the segment
    history = ops.clip(clipped_events, 0, 5, clip_duration=False)

    # Anticipatory generation: the melody events are controls the model
    # must accompany; top_p=0.98 nucleus sampling
    accompaniment = generate(
        model, 0, end_time - start_time, inputs=history, controls=melody, top_p=0.98
    )

    # Re-anchor the generated events at t=0
    accompaniment = ops.translate(accompaniment, -ops.min_time(accompaniment, seconds=False))

    # Merge accompaniment with the melody and clip to the segment length
    output_events = ops.clip(ops.combine(accompaniment, melody), 0, end_time - start_time, clip_duration=True)

    # Convert back to MIDI. Fixed output name — concurrent requests would
    # overwrite each other; consider tempfile if that matters.
    output_midi = "generated_accompaniment.midi"
    mid = events_to_midi(output_events)
    # NOTE(review): save step reconstructed — the diff view elides this line,
    # but the function names `output_midi` and returns that path, so the
    # events_to_midi result must be written to it here.
    mid.save(output_midi)

    return output_midi
85
 
86
# === PyHARP Process Function ===
def process_fn(input_midi, model_choice, start_time, end_time):
    """Adapter between the PyHARP endpoint and the AMT generator.

    Forwards the UI inputs to generate_accompaniment and pairs the
    resulting MIDI path with an empty LabelList, as PyHARP expects.
    """
    return (
        generate_accompaniment(input_midi, model_choice, start_time, end_time),
        LabelList(),
    )
91
 
92
+ # === Build Gradio Interface with Model Selection ===
 
93
  with gr.Blocks() as demo:
94
  components = [
95
+ # Dropdown value is passed through process_fn to load_amt_model
+ gr.Dropdown(
96
+ choices=[SMALL_MODEL, MEDIUM_MODEL, LARGE_MODEL],
97
+ value=MEDIUM_MODEL,
98
+ label="Select AMT Model (Faster vs. Higher Quality)"
99
+ ),
100
  gr.Slider(0, 30, step=1, label="Start Time (seconds)"),
101
  gr.Slider(0, 30, step=1, label="End Time (seconds)"),
102
  ]
 
108
  )
109
 
110
# Queue requests so long-running generations are processed in order
demo.queue()
demo.launch(share=True, show_error=True)