lllindsey0615 committed on
Commit 1f73822 · 1 Parent(s): f4d3324

enable loading large checkpoints using safetensors

Files changed (1): app.py +32 -0
app.py CHANGED
@@ -7,6 +7,9 @@ from anticipation import ops
 from anticipation.tokenize import extract_instruments
 import torch
 from pyharp import *
+from safetensors.torch import load_file
+import os
+
 
 #Model Choices
 SMALL_MODEL = "stanford-crfm/music-small-800k"
@@ -25,6 +28,7 @@ model_card = ModelCard(
 
 model_cache = {}
 
+'''
 def load_amt_model(model_choice):
     """Loads and caches the AMT model inside the worker process."""
     if model_choice in model_cache:
@@ -35,6 +39,34 @@ def load_amt_model(model_choice):
 
     model_cache[model_choice] = model
     return model
+'''
+
+def load_amt_model(model_choice):
+    """Loads and caches the AMT model inside the worker process."""
+    if model_choice in model_cache:
+        return model_cache[model_choice]
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    if model_choice == LARGE_MODEL:
+        # Large model uses safetensors
+        model_dir = "./tmp_music_large"
+        os.makedirs(model_dir, exist_ok=True)
+
+        print(f"Loading {LARGE_MODEL} from safetensors format...")
+        model = AutoModelForCausalLM.from_pretrained(
+            LARGE_MODEL,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+            low_cpu_mem_usage=True
+        ).to(device)
+    else:
+        # Small and medium use standard PyTorch .bin format
+        print(f"Loading {model_choice} from standard format...")
+        model = AutoModelForCausalLM.from_pretrained(model_choice).to(device)
+
+    model_cache[model_choice] = model
+    return model
+
 
 
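Note on the change: the rewritten loader relies on AutoModelForCausalLM.from_pretrained, which already resolves *.safetensors shards on its own, so the newly imported load_file and the scratch directory model_dir are not actually exercised in this revision. For reference, a minimal sketch of loading a safetensors checkpoint by hand, if from_pretrained were ever bypassed; the repo id and checkpoint filename below are assumptions for illustration, not values taken from the commit:

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from safetensors.torch import load_file

# Hypothetical repo id and local checkpoint name; the diff above only
# spells out SMALL_MODEL, so LARGE_MODEL's exact id is assumed here.
REPO_ID = "stanford-crfm/music-large-800k"
CHECKPOINT = "model.safetensors"

# Instantiate the architecture from its config, then load weights manually.
config = AutoConfig.from_pretrained(REPO_ID)
model = AutoModelForCausalLM.from_config(config)

# load_file returns a plain state dict of tensors (no pickle deserialization),
# which is what makes safetensors practical for large checkpoints.
state_dict = load_file(CHECKPOINT, device="cpu")
model.load_state_dict(state_dict)
model.eval()

Either way, repeated calls to load_amt_model hit model_cache first, so each worker process pays the load cost at most once per checkpoint.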