lllindsey0615 committed
Commit 0f15908 Β· 1 Parent(s): e93d344

debug for amt model loading with safetensors

Files changed (1): app.py (+37 -20)
app.py CHANGED
@@ -28,43 +28,61 @@ model_card = ModelCard(
 
 model_cache = {}
 
+'''
+def load_amt_model(model_choice):
+    """Loads and caches the AMT model inside the worker process."""
+    if model_choice in model_cache:
+        return model_cache[model_choice]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
+
+    model_cache[model_choice] = model
+    return model
+'''
+
+import os
+import traceback
+from transformers import AutoModelForCausalLM, AutoConfig
+import torch
+from safetensors.torch import load_file
+
 def load_amt_model(model_choice):
     """Loads and caches the AMT model inside the worker process."""
     if model_choice in model_cache:
-        print(f"βœ… Model {model_choice} loaded from cache")
         return model_cache[model_choice]
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    cache_dir = "./models"
 
     try:
-        print(f"πŸ“Œ Checking available model files for: {model_choice}")
+        print(f"πŸ“Œ Loading model: {model_choice}")
 
-        # Detect model file format
-        model_path = os.path.join(cache_dir, model_choice)
-        bin_path = os.path.join(model_path, "pytorch_model.bin")
-        safetensor_path = os.path.join(model_path, "model.safetensors")
+        # For small and medium models, use the original loading method
+        if model_choice in ["stanford-crfm/music-small-800k", "stanford-crfm/music-medium-800k"]:
+            model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
+            print(f"βœ… Successfully loaded {model_choice} using standard PyTorch method.")
 
-        if os.path.exists(bin_path):
-            print(f"βœ… Detected pytorch_model.bin for {model_choice}. Loading normally...")
-            model = AutoModelForCausalLM.from_pretrained(model_choice, cache_dir=cache_dir).to(device)
+        # For large model, use SafeTensors method
+        elif model_choice == "stanford-crfm/music-large-800k":
+            print(f"πŸ“Œ Detected SafeTensors format for {model_choice}, loading manually...")
 
-        elif os.path.exists(safetensor_path):
-            print(f"βœ… Detected model.safetensors for {model_choice}. Loading using SafeTensors...")
-
-            # Load model configuration
-            config = AutoConfig.from_pretrained(model_choice, cache_dir=cache_dir)
+            # Load model config
+            config = AutoConfig.from_pretrained(model_choice)
 
             # Load SafeTensors manually
-            state_dict = load_file(safetensor_path)
+            safetensor_path = os.path.join("models", model_choice, "model.safetensors")
+
+            if not os.path.exists(safetensor_path):
+                raise FileNotFoundError(f"❌ SafeTensors file not found: {safetensor_path}")
+
+            state_dict = load_file(safetensor_path)
             model = AutoModelForCausalLM.from_config(config)  # Initialize model
             model.load_state_dict(state_dict)  # Load weights
             model.to(device)
+            print(f"βœ… Successfully loaded {model_choice} using SafeTensors.")
 
         else:
-            raise ValueError(f"❌ No valid model file found for {model_choice}")
-
-        print(f"βœ… Successfully loaded model: {model_choice}")
+            raise ValueError(f"❌ Unknown model choice: {model_choice}")
 
     except Exception as e:
         print(f"❌ Error loading model {model_choice}: {e}")
@@ -75,7 +93,6 @@ def load_amt_model(model_choice):
     return model
 
 
-
 @spaces.GPU
 def generate_accompaniment(midi_file, model_choice, selected_midi_program, history_length):
     """Generates accompaniment for the entire MIDI input, conditioned on the user-selected history length."""
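For reference, a minimal standalone sketch of the SafeTensors loading path this commit adds to app.py. The load_from_safetensors helper name is illustrative, and the models/<repo_id>/model.safetensors location only mirrors the path assumed in the diff; neither is part of the actual app.

# Minimal sketch of the manual SafeTensors loading path used in this commit
# (hypothetical helper; adjust the weights path to wherever the checkpoint lives).
import os

import torch
from safetensors.torch import load_file
from transformers import AutoConfig, AutoModelForCausalLM

def load_from_safetensors(model_id="stanford-crfm/music-large-800k"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Build the architecture from its config alone (no weights are downloaded here).
    config = AutoConfig.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_config(config)

    # Read the raw tensors from the local .safetensors file and copy them into the model.
    weights_path = os.path.join("models", model_id, "model.safetensors")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"SafeTensors file not found: {weights_path}")
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)

    return model.to(device).eval()

Note that AutoModelForCausalLM.from_pretrained can also load model.safetensors checkpoints directly when they are part of a standard Hugging Face repo, which is presumably why the small and medium models in the diff keep the plain from_pretrained call; the manual path above appears to be needed only because the large model's weights sit in a local directory rather than a full snapshot.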