ksangk commited on
Commit
771c988
·
1 Parent(s): e473d08

switch to safetensors

Browse files
Files changed (2) hide show
  1. app.py +4 -3
  2. chord/io.py +21 -1
app.py CHANGED
@@ -13,6 +13,7 @@ import spaces
13
  from chord import ChordModel
14
  from chord.module import make
15
  from chord.util import get_positions, rgb_to_srgb
 
16
 
17
  EXAMPLES_USECASE_1 = [
18
  [f"examples/generated/{f}"]
@@ -29,14 +30,14 @@ EXAMPLES_USECASE_3 = [
29
 
30
  MODEL_OBJ = None
31
  login(token=os.environ["HF_TOKEN"])
32
- MODEL_CKPT_PATH = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord", filename="chord_v1.ckpt")
33
  def load_model(ckpt_path):
34
  print("Loading model from:", ckpt_path)
35
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
  config = OmegaConf.load("config/chord.yaml")
37
  model = ChordModel(config)
38
- ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
39
- model.load_state_dict(ckpt["state_dict"])
40
  model.eval()
41
  model.to(device)
42
  return model
 
13
  from chord import ChordModel
14
  from chord.module import make
15
  from chord.util import get_positions, rgb_to_srgb
16
+ from chord.io import load_torch_file
17
 
18
  EXAMPLES_USECASE_1 = [
19
  [f"examples/generated/{f}"]
 
30
 
31
  MODEL_OBJ = None
32
  login(token=os.environ["HF_TOKEN"])
33
+ MODEL_CKPT_PATH = hf_hub_download(repo_id="Ubisoft/ubisoft-laforge-chord", filename="chord_v1.safetensors")
34
  def load_model(ckpt_path):
35
  print("Loading model from:", ckpt_path)
36
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
  config = OmegaConf.load("config/chord.yaml")
38
  model = ChordModel(config)
39
+ state_dict = load_torch_file(ckpt_path)
40
+ model.load_state_dict(state_dict)
41
  model.eval()
42
  model.to(device)
43
  return model
chord/io.py CHANGED
@@ -3,6 +3,7 @@ import imageio.v3 as imageio
3
  import numpy as np
4
  import warnings
5
  import os
 
6
 
7
  import torchvision.transforms.functional as F
8
 
@@ -77,4 +78,23 @@ def save_maps(path: str, maps: dict):
77
  os.makedirs(path)
78
  for name, image in maps.items():
79
  out_img = create_img(image)
80
- out_img.save(os.path.join(path, name+".png"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import numpy as np
4
  import warnings
5
  import os
6
+ import safetensors
7
 
8
  import torchvision.transforms.functional as F
9
 
 
78
  os.makedirs(path)
79
  for name, image in maps.items():
80
  out_img = create_img(image)
81
+ out_img.save(os.path.join(path, name+".png"))
82
+
83
+ def load_torch_file(ckpt, device=None):
84
+ if device is None:
85
+ device = torch.device("cpu")
86
+ if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
87
+ with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
88
+ state_dict = {}
89
+ for k in f.keys():
90
+ tensor = f.get_tensor(k)
91
+ state_dict[k] = tensor
92
+ else:
93
+ torch_args = {}
94
+ ckpt = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
95
+
96
+ if "state_dict" in ckpt:
97
+ state_dict = ckpt["state_dict"]
98
+ else:
99
+ state_dict = ckpt
100
+ return state_dict