Spaces:
Sleeping
Sleeping
fix: load T3 model state with device mapping
Browse files
chatterbox/src/chatterbox/tts.py
CHANGED
|
@@ -133,7 +133,7 @@ class ChatterboxTTS:
|
|
| 133 |
ve.to(device).eval()
|
| 134 |
|
| 135 |
t3 = T3()
|
| 136 |
-
t3_state = torch.load(ckpt_dir / "t3_cfg.pt")
|
| 137 |
if "model" in t3_state.keys():
|
| 138 |
t3_state = t3_state["model"][0]
|
| 139 |
t3.load_state_dict(t3_state)
|
|
@@ -141,7 +141,7 @@ class ChatterboxTTS:
|
|
| 141 |
|
| 142 |
s3gen = S3Gen()
|
| 143 |
s3gen.load_state_dict(
|
| 144 |
-
torch.load(ckpt_dir / "s3gen.pt")
|
| 145 |
)
|
| 146 |
s3gen.to(device).eval()
|
| 147 |
|
|
|
|
| 133 |
ve.to(device).eval()
|
| 134 |
|
| 135 |
t3 = T3()
|
| 136 |
+
t3_state = torch.load(ckpt_dir / "t3_cfg.pt", map_location=torch.device(device))
|
| 137 |
if "model" in t3_state.keys():
|
| 138 |
t3_state = t3_state["model"][0]
|
| 139 |
t3.load_state_dict(t3_state)
|
|
|
|
| 141 |
|
| 142 |
s3gen = S3Gen()
|
| 143 |
s3gen.load_state_dict(
|
| 144 |
+
torch.load(ckpt_dir / "s3gen.pt", map_location=torch.device(device))
|
| 145 |
)
|
| 146 |
s3gen.to(device).eval()
|
| 147 |
|