playmak3r commited on
Commit
57450e4
·
1 Parent(s): 55ff79e

fix: load T3 model state with device mapping

Browse files
Files changed (1) hide show
  1. chatterbox/src/chatterbox/tts.py +2 -2
chatterbox/src/chatterbox/tts.py CHANGED
@@ -133,7 +133,7 @@ class ChatterboxTTS:
133
  ve.to(device).eval()
134
 
135
  t3 = T3()
136
- t3_state = torch.load(ckpt_dir / "t3_cfg.pt")
137
  if "model" in t3_state.keys():
138
  t3_state = t3_state["model"][0]
139
  t3.load_state_dict(t3_state)
@@ -141,7 +141,7 @@ class ChatterboxTTS:
141
 
142
  s3gen = S3Gen()
143
  s3gen.load_state_dict(
144
- torch.load(ckpt_dir / "s3gen.pt")
145
  )
146
  s3gen.to(device).eval()
147
 
 
133
  ve.to(device).eval()
134
 
135
  t3 = T3()
136
+ t3_state = torch.load(ckpt_dir / "t3_cfg.pt", map_location=torch.device(device))
137
  if "model" in t3_state.keys():
138
  t3_state = t3_state["model"][0]
139
  t3.load_state_dict(t3_state)
 
141
 
142
  s3gen = S3Gen()
143
  s3gen.load_state_dict(
144
+ torch.load(ckpt_dir / "s3gen.pt", map_location=torch.device(device))
145
  )
146
  s3gen.to(device).eval()
147