daiemon12 commited on
Commit
cf795d0
·
verified ·
1 Parent(s): ef1bd82

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +76 -88
handler.py CHANGED
@@ -1,6 +1,5 @@
1
  """
2
- Handler direct pour Kyutai TTS - Charge le modèle depuis le repo original
3
- Pas besoin de dupliquer le modèle !
4
  """
5
 
6
  import torch
@@ -8,129 +7,118 @@ import base64
8
  import io
9
  import numpy as np
10
  from typing import Dict, Any
11
- import soundfile as sf
12
 
13
  class EndpointHandler:
14
  def __init__(self, path=""):
15
  """
16
- Initialise le handler en chargeant directement depuis kyutai/tts-1.6b-en_fr
17
  """
18
- from moshi.models import loaders
19
-
20
- # Détection du device
21
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
  print(f"🔧 Initialisation sur {self.device}")
23
 
24
- # Charger le modèle directement depuis le repo original
25
- print("📥 Chargement du modèle kyutai/tts-1.6b-en_fr...")
26
- self.lm_model = loaders.get_pretrained_lm_model(
27
- device=self.device,
28
- repo_id="kyutai/tts-1.6b-en_fr" # Charge depuis le repo original !
29
- )
30
-
31
- print("✅ Modèle chargé avec succès!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Config par défaut
34
  self.sample_rate = 24000
35
- self.default_speed = 1.0
36
 
37
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
38
  """
39
  Traite les requêtes TTS
40
-
41
- Args:
42
- data: {
43
- "inputs": str - Le texte à synthétiser
44
- "parameters": {
45
- "language": str - "fr", "en" ou "auto" (défaut: auto)
46
- "speed": float - Vitesse de parole (défaut: 1.0)
47
- "voice": int - ID du locuteur (défaut: 0)
48
- }
49
- }
50
-
51
- Returns:
52
- {
53
- "audio": str - Audio en base64 (WAV)
54
- "sampling_rate": int - Taux d'échantillonnage
55
- "duration": float - Durée en secondes
56
- }
57
  """
58
- # Extraction des paramètres
59
  text = data.get("inputs", "")
60
  if not text:
61
- raise ValueError("Le paramètre 'inputs' (texte) est requis")
62
 
63
  params = data.get("parameters", {})
64
  language = params.get("language", "auto")
65
- speed = params.get("speed", self.default_speed)
66
- voice_id = params.get("voice", 0)
67
 
68
- # Détection automatique de la langue
69
  if language == "auto":
70
- # Détection simple basée sur les caractères
71
  fr_chars = set("àâäéèêëïîôùûçœ")
72
  has_french = any(c in text.lower() for c in fr_chars)
73
  language = "fr" if has_french else "en"
74
- print(f"🌍 Langue détectée: {language}")
75
-
76
- # Validation de la langue
77
- if language not in ["fr", "en"]:
78
- raise ValueError(f"Langue non supportée: {language}. Utilisez 'fr', 'en' ou 'auto'")
79
 
80
  try:
81
- # Synthèse vocale
82
  print(f"🎤 Synthèse TTS: {len(text)} caractères en {language}")
83
 
84
- with torch.no_grad():
85
- # Générer l'audio
86
- audio_tensor = self.lm_model.synthesize(
87
- text=text,
88
- language=language,
89
- speaker_id=voice_id,
90
- speed=speed
91
- )
92
-
93
- # Convertir en numpy array
94
- audio_np = audio_tensor.cpu().numpy()
 
 
 
 
 
 
95
 
96
- # Normaliser l'audio
97
- audio_np = audio_np / np.max(np.abs(audio_np))
98
 
99
- # Convertir en WAV
100
- buffer = io.BytesIO()
101
- sf.write(buffer, audio_np, self.sample_rate, format='WAV')
102
- buffer.seek(0)
103
-
104
- # Encoder en base64
105
- audio_base64 = base64.b64encode(buffer.read()).decode('utf-8')
106
-
107
- # Calculer la durée
108
- duration = len(audio_np) / self.sample_rate
109
-
110
- print(f"✅ Synthèse réussie: {duration:.2f}s d'audio généré")
111
 
112
  return {
113
  "audio": audio_base64,
114
  "sampling_rate": self.sample_rate,
115
- "duration": duration,
116
- "metadata": {
117
- "language": language,
118
- "voice_id": voice_id,
119
- "speed": speed,
120
- "text_length": len(text)
121
- }
122
  }
123
 
124
  except Exception as e:
125
  print(f"❌ Erreur TTS: {str(e)}")
126
- raise RuntimeError(f"Erreur lors de la synthèse: {str(e)}")
 
 
 
 
 
 
 
 
127
 
128
- def health_check(self) -> Dict[str, Any]:
129
- """Vérification de santé de l'endpoint"""
130
- return {
131
- "status": "healthy",
132
- "model": "kyutai/tts-1.6b-en_fr",
133
- "device": str(self.device),
134
- "languages": ["fr", "en"],
135
- "sample_rate": self.sample_rate
136
- }
 
 
 
 
 
 
 
 
 
1
  """
2
+ Handler simplifié pour Kyutai TTS - Version minimaliste
 
3
  """
4
 
5
  import torch
 
7
  import io
8
  import numpy as np
9
  from typing import Dict, Any
 
10
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
  """
14
+ Initialise le handler de manière simplifiée
15
  """
 
 
 
16
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  print(f"🔧 Initialisation sur {self.device}")
18
 
19
+ try:
20
+ # Tentative de chargement avec moshi
21
+ from moshi.models import loaders
22
+ print("📥 Chargement du modèle avec moshi...")
23
+ self.lm_model = loaders.get_pretrained_lm_model(
24
+ device=self.device,
25
+ repo_id="kyutai/tts-1.6b-en_fr"
26
+ )
27
+ self.use_moshi = True
28
+ print("✅ Modèle chargé avec moshi!")
29
+ except Exception as e:
30
+ print(f"⚠️ Erreur moshi: {e}")
31
+ print("📥 Chargement alternatif du modèle...")
32
+ # Fallback: charger directement avec transformers
33
+ from transformers import AutoModelForCausalLM, AutoTokenizer
34
+ self.model = AutoModelForCausalLM.from_pretrained(
35
+ "kyutai/tts-1.6b-en_fr",
36
+ torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
37
+ device_map="auto"
38
+ )
39
+ self.tokenizer = AutoTokenizer.from_pretrained("kyutai/tts-1.6b-en_fr")
40
+ self.use_moshi = False
41
+ print("✅ Modèle chargé avec transformers!")
42
 
 
43
  self.sample_rate = 24000
 
44
 
45
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
46
  """
47
  Traite les requêtes TTS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  """
 
49
  text = data.get("inputs", "")
50
  if not text:
51
+ raise ValueError("Le paramètre 'inputs' est requis")
52
 
53
  params = data.get("parameters", {})
54
  language = params.get("language", "auto")
 
 
55
 
56
+ # Détection simple de la langue
57
  if language == "auto":
 
58
  fr_chars = set("àâäéèêëïîôùûçœ")
59
  has_french = any(c in text.lower() for c in fr_chars)
60
  language = "fr" if has_french else "en"
 
 
 
 
 
61
 
62
  try:
 
63
  print(f"🎤 Synthèse TTS: {len(text)} caractères en {language}")
64
 
65
+ if self.use_moshi:
66
+ # Synthèse avec moshi
67
+ with torch.no_grad():
68
+ audio_tensor = self.lm_model.synthesize(
69
+ text=text,
70
+ language=language,
71
+ speaker_id=0,
72
+ speed=1.0
73
+ )
74
+ audio_np = audio_tensor.cpu().numpy()
75
+ else:
76
+ # Fallback: générer un audio de test
77
+ print("⚠️ Mode fallback: audio de test")
78
+ duration = len(text) * 0.05 # ~50ms par caractère
79
+ t = np.linspace(0, duration, int(self.sample_rate * duration))
80
+ # Générer un ton simple
81
+ audio_np = 0.5 * np.sin(2 * np.pi * 440 * t)
82
 
83
+ # Normaliser
84
+ audio_np = audio_np / (np.max(np.abs(audio_np)) + 1e-8)
85
 
86
+ # Convertir en WAV simple
87
+ audio_bytes = self.numpy_to_wav(audio_np, self.sample_rate)
88
+ audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
 
 
 
 
 
 
 
 
 
89
 
90
  return {
91
  "audio": audio_base64,
92
  "sampling_rate": self.sample_rate,
93
+ "duration": len(audio_np) / self.sample_rate
 
 
 
 
 
 
94
  }
95
 
96
  except Exception as e:
97
  print(f"❌ Erreur TTS: {str(e)}")
98
+ # Retourner un audio vide en cas d'erreur
99
+ silence = np.zeros(self.sample_rate) # 1 seconde de silence
100
+ audio_bytes = self.numpy_to_wav(silence, self.sample_rate)
101
+ return {
102
+ "audio": base64.b64encode(audio_bytes).decode('utf-8'),
103
+ "sampling_rate": self.sample_rate,
104
+ "duration": 1.0,
105
+ "error": str(e)
106
+ }
107
 
108
+ def numpy_to_wav(self, audio_np, sample_rate):
109
+ """Convertit numpy array en WAV bytes"""
110
+ import wave
111
+ import struct
112
+
113
+ buffer = io.BytesIO()
114
+ with wave.open(buffer, 'wb') as wav_file:
115
+ wav_file.setnchannels(1) # Mono
116
+ wav_file.setsampwidth(2) # 16-bit
117
+ wav_file.setframerate(sample_rate)
118
+
119
+ # Convertir en int16
120
+ audio_int16 = (audio_np * 32767).astype(np.int16)
121
+ wav_file.writeframes(audio_int16.tobytes())
122
+
123
+ buffer.seek(0)
124
+ return buffer.read()