daiemon12 commited on
Commit
43626b5
·
verified ·
1 Parent(s): 3e75530

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +84 -52
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Handler simplifié pour Kyutai TTS - Version minimaliste
3
  """
4
 
5
  import torch
@@ -11,36 +11,29 @@ from typing import Dict, Any
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
  """
14
- Initialise le handler de manière simplifiée
15
  """
16
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  print(f"🔧 Initialisation sur {self.device}")
18
 
 
 
 
 
19
  try:
20
- # Tentative de chargement avec moshi
21
  from moshi.models import loaders
22
- print("📥 Chargement du modèle avec moshi...")
23
  self.lm_model = loaders.get_pretrained_lm_model(
24
  device=self.device,
25
  repo_id="kyutai/tts-1.6b-en_fr"
26
  )
27
- self.use_moshi = True
28
- print("✅ Modèle chargé avec moshi!")
29
  except Exception as e:
30
- print(f"⚠️ Erreur moshi: {e}")
31
- print("📥 Chargement alternatif du modèle...")
32
- # Fallback: charger directement avec transformers
33
- from transformers import AutoModelForCausalLM, AutoTokenizer
34
- self.model = AutoModelForCausalLM.from_pretrained(
35
- "kyutai/tts-1.6b-en_fr",
36
- torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
37
- device_map="auto"
38
- )
39
- self.tokenizer = AutoTokenizer.from_pretrained("kyutai/tts-1.6b-en_fr")
40
- self.use_moshi = False
41
- print("✅ Modèle chargé avec transformers!")
42
-
43
- self.sample_rate = 24000
44
 
45
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
46
  """
@@ -53,17 +46,16 @@ class EndpointHandler:
53
  params = data.get("parameters", {})
54
  language = params.get("language", "auto")
55
 
56
- # Détection simple de la langue
57
  if language == "auto":
58
  fr_chars = set("àâäéèêëïîôùûçœ")
59
  has_french = any(c in text.lower() for c in fr_chars)
60
  language = "fr" if has_french else "en"
61
 
62
  try:
63
- print(f"🎤 Synthèse TTS: {len(text)} caractères en {language}")
64
-
65
- if self.use_moshi:
66
- # Synthèse avec moshi
67
  with torch.no_grad():
68
  audio_tensor = self.lm_model.synthesize(
69
  text=text,
@@ -73,52 +65,92 @@ class EndpointHandler:
73
  )
74
  audio_np = audio_tensor.cpu().numpy()
75
  else:
76
- # Fallback: générer un audio de test
77
- print("⚠️ Mode fallback: audio de test")
78
- duration = len(text) * 0.05 # ~50ms par caractère
79
- t = np.linspace(0, duration, int(self.sample_rate * duration))
80
- # Générer un ton simple
81
- audio_np = 0.5 * np.sin(2 * np.pi * 440 * t)
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- # Normaliser
84
- audio_np = audio_np / (np.max(np.abs(audio_np)) + 1e-8)
 
85
 
86
- # Convertir en WAV simple
87
  audio_bytes = self.numpy_to_wav(audio_np, self.sample_rate)
88
  audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
89
 
90
  return {
91
  "audio": audio_base64,
92
  "sampling_rate": self.sample_rate,
93
- "duration": len(audio_np) / self.sample_rate
 
 
94
  }
95
 
96
  except Exception as e:
97
- print(f"❌ Erreur TTS: {str(e)}")
98
- # Retourner un audio vide en cas d'erreur
99
- silence = np.zeros(self.sample_rate) # 1 seconde de silence
100
  audio_bytes = self.numpy_to_wav(silence, self.sample_rate)
101
  return {
102
  "audio": base64.b64encode(audio_bytes).decode('utf-8'),
103
  "sampling_rate": self.sample_rate,
104
- "duration": 1.0,
105
- "error": str(e)
 
106
  }
107
 
108
  def numpy_to_wav(self, audio_np, sample_rate):
109
  """Convertit numpy array en WAV bytes"""
110
- import wave
111
  import struct
112
 
113
- buffer = io.BytesIO()
114
- with wave.open(buffer, 'wb') as wav_file:
115
- wav_file.setnchannels(1) # Mono
116
- wav_file.setsampwidth(2) # 16-bit
117
- wav_file.setframerate(sample_rate)
118
-
119
- # Convertir en int16
120
- audio_int16 = (audio_np * 32767).astype(np.int16)
121
- wav_file.writeframes(audio_int16.tobytes())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- buffer.seek(0)
124
- return buffer.read()
 
1
  """
2
+ Handler final pour Kyutai TTS - Compatible HF Endpoints
3
  """
4
 
5
  import torch
 
11
  class EndpointHandler:
12
  def __init__(self, path=""):
13
  """
14
+ Initialise le handler avec un fallback audio simple
15
  """
16
  self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  print(f"🔧 Initialisation sur {self.device}")
18
 
19
+ # Configuration
20
+ self.sample_rate = 24000
21
+ self.model_loaded = False
22
+
23
  try:
24
+ # Essayer de charger moshi
25
  from moshi.models import loaders
26
+ print("📥 Tentative de chargement avec moshi...")
27
  self.lm_model = loaders.get_pretrained_lm_model(
28
  device=self.device,
29
  repo_id="kyutai/tts-1.6b-en_fr"
30
  )
31
+ self.model_loaded = True
32
+ print("✅ Modèle Kyutai chargé avec succès!")
33
  except Exception as e:
34
+ print(f"⚠️ Impossible de charger Kyutai TTS: {e}")
35
+ print("🔄 Mode fallback activé - génération audio basique")
36
+ self.model_loaded = False
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
39
  """
 
46
  params = data.get("parameters", {})
47
  language = params.get("language", "auto")
48
 
49
+ # Détection de la langue
50
  if language == "auto":
51
  fr_chars = set("àâäéèêëïîôùûçœ")
52
  has_french = any(c in text.lower() for c in fr_chars)
53
  language = "fr" if has_french else "en"
54
 
55
  try:
56
+ if self.model_loaded:
57
+ # Utiliser le vrai modèle Kyutai
58
+ print(f"🎤 Synthèse Kyutai TTS: {len(text)} caractères en {language}")
 
59
  with torch.no_grad():
60
  audio_tensor = self.lm_model.synthesize(
61
  text=text,
 
65
  )
66
  audio_np = audio_tensor.cpu().numpy()
67
  else:
68
+ # Fallback: générer un placeholder audio
69
+ print(f"🎵 Mode fallback: génération audio simple pour {len(text)} caractères")
70
+ duration = min(len(text) * 0.06, 10.0) # ~60ms par caractère, max 10s
71
+ samples = int(self.sample_rate * duration)
72
+
73
+ # Générer une voix synthétique simple
74
+ t = np.linspace(0, duration, samples)
75
+ # Fréquences pour simuler une voix
76
+ f1 = 200 + 50 * np.sin(2 * np.pi * 3 * t) # Modulation lente
77
+ f2 = 400 + 100 * np.sin(2 * np.pi * 2 * t)
78
+
79
+ # Combiner plusieurs harmoniques
80
+ audio_np = 0.3 * np.sin(2 * np.pi * f1 * t)
81
+ audio_np += 0.2 * np.sin(2 * np.pi * f2 * t)
82
+ audio_np += 0.1 * np.sin(2 * np.pi * 800 * t)
83
+
84
+ # Enveloppe pour rendre plus naturel
85
+ envelope = np.exp(-t / duration * 2)
86
+ audio_np *= envelope
87
 
88
+ # Normaliser l'audio
89
+ if np.max(np.abs(audio_np)) > 0:
90
+ audio_np = audio_np / np.max(np.abs(audio_np)) * 0.8
91
 
92
+ # Convertir en WAV
93
  audio_bytes = self.numpy_to_wav(audio_np, self.sample_rate)
94
  audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
95
 
96
  return {
97
  "audio": audio_base64,
98
  "sampling_rate": self.sample_rate,
99
+ "duration": len(audio_np) / self.sample_rate,
100
+ "model_loaded": self.model_loaded,
101
+ "language": language
102
  }
103
 
104
  except Exception as e:
105
+ print(f"❌ Erreur lors de la synthèse: {str(e)}")
106
+ # Retourner un court silence en cas d'erreur
107
+ silence = np.zeros(int(self.sample_rate * 0.5)) # 0.5s de silence
108
  audio_bytes = self.numpy_to_wav(silence, self.sample_rate)
109
  return {
110
  "audio": base64.b64encode(audio_bytes).decode('utf-8'),
111
  "sampling_rate": self.sample_rate,
112
+ "duration": 0.5,
113
+ "error": str(e),
114
+ "model_loaded": self.model_loaded
115
  }
116
 
117
  def numpy_to_wav(self, audio_np, sample_rate):
118
  """Convertit numpy array en WAV bytes"""
 
119
  import struct
120
 
121
+ # Ensure audio is 1D
122
+ if audio_np.ndim > 1:
123
+ audio_np = audio_np.flatten()
124
+
125
+ # Convert to int16
126
+ audio_int16 = (audio_np * 32767).astype(np.int16)
127
+
128
+ # Create WAV header
129
+ num_samples = len(audio_int16)
130
+ num_channels = 1
131
+ bits_per_sample = 16
132
+ byte_rate = sample_rate * num_channels * bits_per_sample // 8
133
+ block_align = num_channels * bits_per_sample // 8
134
+
135
+ # WAV file structure
136
+ wav_header = struct.pack(
137
+ '<4sI4s4sIHHIIHH4sI',
138
+ b'RIFF',
139
+ 36 + num_samples * 2, # ChunkSize
140
+ b'WAVE',
141
+ b'fmt ',
142
+ 16, # Subchunk1Size (PCM)
143
+ 1, # AudioFormat (PCM)
144
+ num_channels,
145
+ sample_rate,
146
+ byte_rate,
147
+ block_align,
148
+ bits_per_sample,
149
+ b'data',
150
+ num_samples * 2 # Subchunk2Size
151
+ )
152
+
153
+ # Combine header and audio data
154
+ wav_data = wav_header + audio_int16.tobytes()
155
 
156
+ return wav_data