patocolher commited on
Commit
365339b
·
verified ·
1 Parent(s): d3b67ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -2
app.py CHANGED
@@ -8,6 +8,44 @@ import time
8
  import torch
9
  import torchaudio
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  #download for mecab
13
  os.system('python -m unidic download')
@@ -29,8 +67,6 @@ from scipy.io.wavfile import write
29
  from pydub import AudioSegment
30
 
31
  from TTS.api import TTS
32
- from TTS.tts.configs.xtts_config import XttsConfig
33
- from TTS.tts.models.xtts import Xtts
34
  from TTS.utils.generic_utils import get_user_data_dir
35
 
36
  HF_TOKEN = os.environ.get("HF_TOKEN")
 
8
  import torch
9
  import torchaudio
10
 
11
+ # ============================================================================
12
+ # Fix for PyTorch 2.6+ - Add safe globals for TTS classes
13
+ # PyTorch 2.6 changed the default value of `weights_only` from False to True
14
+ # This breaks loading TTS models that use custom classes
15
+ # ============================================================================
16
+ import torch.serialization
17
+
18
+ # First, try to import all necessary TTS classes
19
+ try:
20
+ from TTS.tts.configs.xtts_config import XttsConfig
21
+ from TTS.tts.models.xtts import Xtts
22
+ from TTS.config.shared_configs import BaseDatasetConfig
23
+
24
+ # Add all necessary TTS classes to safe globals
25
+ torch.serialization.add_safe_globals([
26
+ XttsConfig,
27
+ Xtts,
28
+ BaseDatasetConfig,
29
+ ])
30
+ print("Added TTS classes to torch safe globals")
31
+ except Exception as e:
32
+ print(f"Warning: Could not add safe globals: {e}")
33
+
34
+ # Fallback: Monkey patch torch.load to use weights_only=False for TTS models
35
+ # This is needed because TTS models may contain additional custom classes
36
+ _original_torch_load = torch.load
37
+
38
+ def _patched_torch_load(*args, **kwargs):
39
+ # For .pth files (model checkpoints), default to weights_only=False if not explicitly set
40
+ if 'weights_only' not in kwargs:
41
+ if len(args) > 0 and isinstance(args[0], str) and args[0].endswith('.pth'):
42
+ kwargs['weights_only'] = False
43
+ return _original_torch_load(*args, **kwargs)
44
+
45
+ torch.load = _patched_torch_load
46
+ print("Applied torch.load patch for PyTorch 2.6+ compatibility")
47
+ # ============================================================================
48
+
49
 
50
  #download for mecab
51
  os.system('python -m unidic download')
 
67
  from pydub import AudioSegment
68
 
69
  from TTS.api import TTS
 
 
70
  from TTS.utils.generic_utils import get_user_data_dir
71
 
72
  HF_TOKEN = os.environ.get("HF_TOKEN")