"""
Patch Coqui TTS for compatibility with recent PyTorch releases (2.6+, including 2.11+).

Since PyTorch 2.6, torch.load() defaults to weights_only=True instead of False.
XTTS-v2 checkpoints contain config objects that require weights_only=False,
so TTS/utils/io.py is rewritten to pass weights_only=False explicitly.

Usage:
    conda activate new-arabic-tts
    python scripts/patch_tts.py
"""
import importlib.util
import os
import site
import sys
def _find_tts_io_path():
    """Return the path of ``TTS/utils/io.py`` for the running interpreter.

    ``importlib.util.find_spec`` resolves the top-level ``TTS`` package the
    same way ``import TTS`` would (conda env, venv or user site-packages)
    without executing the package. If the package cannot be found that way,
    fall back to the first entry of ``site.getsitepackages()``.
    """
    spec = importlib.util.find_spec("TTS")
    if spec is not None and spec.submodule_search_locations:
        package_dir = list(spec.submodule_search_locations)[0]
        return os.path.join(package_dir, "utils", "io.py")
    return os.path.join(site.getsitepackages()[0], "TTS", "utils", "io.py")


def patch(io_path=None):
    """Make Coqui TTS's ``torch.load`` calls pass ``weights_only=False``.

    Every occurrence of
    ``return torch.load(f, map_location=map_location, **kwargs)`` in the
    target file is rewritten to pass ``weights_only=False`` explicitly.
    Running the script again is harmless. Once no unpatched call remains,
    the file is reported as already patched.

    Args:
        io_path: Path of the ``io.py`` file to patch. Defaults to
            ``TTS/utils/io.py`` of the ``TTS`` package importable by this
            interpreter.

    Raises:
        SystemExit: with code 1 if the file does not exist, or if it has
            neither an unpatched ``torch.load`` call nor any
            ``weights_only=False``.
    """
    original = "return torch.load(f, map_location=map_location, **kwargs)"
    replacement = (
        "return torch.load(f, map_location=map_location, weights_only=False, **kwargs)"
    )

    if io_path is None:
        io_path = _find_tts_io_path()
    if not os.path.exists(io_path):
        print(f"ERROR: {io_path} not found. Is Coqui TTS installed?")
        sys.exit(1)

    # Python sources are UTF-8 by default (PEP 3120); newline="" keeps the
    # file's existing line endings intact on every platform.
    with open(io_path, "r", encoding="utf-8", newline="") as f:
        content = f.read()

    # Count unpatched calls first, so that a stray "weights_only=False"
    # elsewhere in the file cannot hide calls that still need patching.
    count = content.count(original)
    if count == 0:
        if "weights_only=False" in content:
            print("Already patched.")
            return
        print("ERROR: Could not find torch.load pattern to patch.")
        sys.exit(1)

    with open(io_path, "w", encoding="utf-8", newline="") as f:
        f.write(content.replace(original, replacement))
    print(f"Patched {count} torch.load() calls in {io_path}")
# Run the patch only when executed as a script, not when imported.
if __name__ == "__main__":
    patch()