dieumercimvemba commited on
Commit
237ade0
·
verified ·
1 Parent(s): 1c0a33d

Update data/generate_audio.py

Browse files
Files changed (1) hide show
  1. data/generate_audio.py +15 -14
data/generate_audio.py CHANGED
@@ -1,19 +1,20 @@
1
- # Fichier: /data/generate_audio.py (FIX SÉCURITÉ PYTORCH 2.6 + TACOTRON2 MAI)
2
  import sys
3
  import os
4
  import json
5
  import wave
6
- import torch # <--- IMPORTANT : On importe torch
7
-
8
- # --- FIX SÉCURITÉ PYTORCH ---
9
- # On autorise explicitement le composant RAdam qui bloque le chargement
10
- try:
11
- from TTS.utils.radam import RAdam
12
- if hasattr(torch.serialization, 'add_safe_globals'):
13
- torch.serialization.add_safe_globals([RAdam])
14
- except Exception:
15
- # Si l'import direct échoue, on utilise une méthode plus permissive pour PyTorch 2.6+
16
- pass
 
17
 
18
  from TTS.api import TTS
19
 
@@ -25,9 +26,9 @@ text = sys.argv[1].replace('"', '')
25
  output_file = sys.argv[2]
26
 
27
  try:
28
- print("Tentative de chargement : tts_models/fr/mai/tacotron2-DDC", file=sys.stderr)
29
 
30
- # On charge le modèle
31
  tts = TTS(model_name="tts_models/fr/mai/tacotron2-DDC", progress_bar=False, gpu=False)
32
 
33
  print(f"Génération en cours...", file=sys.stderr)
 
1
+ # Fichier: /data/generate_audio.py (VERSION FORCE-LOAD STABLE)
2
  import sys
3
  import os
4
  import json
5
  import wave
6
+ import collections
7
+ import torch
8
+
9
+ # --- LE CORRECTIF DÉFINITIF POUR PYTORCH 2.6+ ---
10
+ # Cette fonction force PyTorch à ignorer la nouvelle restriction "weights_only"
11
+ # qui bloque les anciens modèles Coqui TTS.
12
+ original_load = torch.load
13
+ def patched_load(*args, **kwargs):
14
+ kwargs['weights_only'] = False
15
+ return original_load(*args, **kwargs)
16
+ torch.load = patched_load
17
+ # -----------------------------------------------
18
 
19
  from TTS.api import TTS
20
 
 
26
  output_file = sys.argv[2]
27
 
28
  try:
29
+ print("Tentative de chargement (Mode Forcé) : tts_models/fr/mai/tacotron2-DDC", file=sys.stderr)
30
 
31
+ # Initialisation du modèle
32
  tts = TTS(model_name="tts_models/fr/mai/tacotron2-DDC", progress_bar=False, gpu=False)
33
 
34
  print(f"Génération en cours...", file=sys.stderr)