lllindsey0615 commited on
Commit
3f427f8
Β·
1 Parent(s): 0f15908

debug for amt model loading with safetensors 2

Browse files
Files changed (1) hide show
  1. app.py +44 -18
app.py CHANGED
@@ -41,14 +41,8 @@ def load_amt_model(model_choice):
41
  return model
42
  '''
43
 
44
- import os
45
- import traceback
46
- from transformers import AutoModelForCausalLM, AutoConfig
47
- import torch
48
- from safetensors.torch import load_file
49
 
50
  def load_amt_model(model_choice):
51
- """Loads and caches the AMT model inside the worker process."""
52
  if model_choice in model_cache:
53
  return model_cache[model_choice]
54
 
@@ -56,30 +50,19 @@ def load_amt_model(model_choice):
56
 
57
  try:
58
  print(f"πŸ“Œ Loading model: {model_choice}")
59
-
60
- # For small and medium models, use the original loading method
61
  if model_choice in ["stanford-crfm/music-small-800k", "stanford-crfm/music-medium-800k"]:
62
  model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
63
- print(f"βœ… Successfully loaded {model_choice} using standard PyTorch method.")
64
-
65
- # For large model, use SafeTensors method
66
  elif model_choice == "stanford-crfm/music-large-800k":
67
- print(f"πŸ“Œ Detected SafeTensors format for {model_choice}, loading manually...")
68
-
69
- # Load model config
70
  config = AutoConfig.from_pretrained(model_choice)
71
-
72
- # Load SafeTensors manually
73
  safetensor_path = os.path.join("models", model_choice, "model.safetensors")
74
-
75
  if not os.path.exists(safetensor_path):
76
  raise FileNotFoundError(f"❌ SafeTensors file not found: {safetensor_path}")
77
 
 
78
  state_dict = load_file(safetensor_path)
79
  model = AutoModelForCausalLM.from_config(config) # Initialize model
80
  model.load_state_dict(state_dict) # Load weights
81
  model.to(device)
82
- print(f"βœ… Successfully loaded {model_choice} using SafeTensors.")
83
 
84
  else:
85
  raise ValueError(f"❌ Unknown model choice: {model_choice}")
@@ -94,6 +77,7 @@ def load_amt_model(model_choice):
94
 
95
 
96
  @spaces.GPU
 
97
  def generate_accompaniment(midi_file, model_choice, selected_midi_program, history_length):
98
  """Generates accompaniment for the entire MIDI input, conditioned on the user-selected history length."""
99
 
@@ -129,6 +113,48 @@ def generate_accompaniment(midi_file, model_choice, selected_midi_program, histo
129
  mid.save(output_midi)
130
 
131
  return output_midi, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
 
134
  def process_fn(input_midi, model_choice, selected_midi_program, history_length):
 
41
  return model
42
  '''
43
 
 
 
 
 
 
44
 
45
  def load_amt_model(model_choice):
 
46
  if model_choice in model_cache:
47
  return model_cache[model_choice]
48
 
 
50
 
51
  try:
52
  print(f"πŸ“Œ Loading model: {model_choice}")
 
 
53
  if model_choice in ["stanford-crfm/music-small-800k", "stanford-crfm/music-medium-800k"]:
54
  model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
 
 
 
55
  elif model_choice == "stanford-crfm/music-large-800k":
 
 
 
56
  config = AutoConfig.from_pretrained(model_choice)
 
 
57
  safetensor_path = os.path.join("models", model_choice, "model.safetensors")
 
58
  if not os.path.exists(safetensor_path):
59
  raise FileNotFoundError(f"❌ SafeTensors file not found: {safetensor_path}")
60
 
61
+ # Load SafeTensors model weights
62
  state_dict = load_file(safetensor_path)
63
  model = AutoModelForCausalLM.from_config(config) # Initialize model
64
  model.load_state_dict(state_dict) # Load weights
65
  model.to(device)
 
66
 
67
  else:
68
  raise ValueError(f"❌ Unknown model choice: {model_choice}")
 
77
 
78
 
79
  @spaces.GPU
80
+ '''
81
  def generate_accompaniment(midi_file, model_choice, selected_midi_program, history_length):
82
  """Generates accompaniment for the entire MIDI input, conditioned on the user-selected history length."""
83
 
 
113
  mid.save(output_midi)
114
 
115
  return output_midi, None
116
+ '''
117
+ def generate_accompaniment(midi_file, model_choice, selected_midi_program, history_length):
118
+ model = load_amt_model(model_choice)
119
+ # Ensure model loaded successfully before using it
120
+ if model is None:
121
+ print(" Model loading failed. Returning error.")
122
+ return None, "⚠️ Model failed to load."
123
+
124
+ print(f"Model loaded successfully: {model_choice}")
125
+
126
+ events = midi_to_events(midi_file.name)
127
+ total_time = round(ops.max_time(events, seconds=True))
128
+ events, melody = extract_instruments(events, [selected_midi_program])
129
+
130
+ if not melody:
131
+ print("No melody detected. Please select a valid MIDI program.")
132
+ return None, "Please select a valid MIDI program that contains events."
133
+
134
+ history = ops.clip(events, 0, history_length, clip_duration=False)
135
+
136
+ # Generate accompaniment for the remaining duration
137
+ accompaniment = generate(
138
+ model,
139
+ history_length,
140
+ total_time,
141
+ inputs=history,
142
+ controls=melody,
143
+ top_p=0.95,
144
+ debug=True # Enable debug mode if supported
145
+ )
146
+
147
+ # Combine accompaniment with the melody
148
+ output_events = ops.clip(ops.combine(accompaniment, melody), 0, total_time, clip_duration=True)
149
+
150
+ # Convert back to MIDI
151
+ output_midi = "generated_accompaniment_huggingface.mid"
152
+ mid = events_to_midi(output_events)
153
+ mid.save(output_midi)
154
+
155
+ print("βœ… MIDI generation successful.")
156
+
157
+ return output_midi, None
158
 
159
 
160
  def process_fn(input_midi, model_choice, selected_midi_program, history_length):