# kyutai-tts-handler / handler.py
"""
Handler final pour Kyutai TTS - Compatible HF Endpoints
"""
import torch
import base64
import io
import numpy as np
from typing import Dict, Any
class EndpointHandler:
    """Text-to-speech handler for Hugging Face Inference Endpoints.

    Tries to load the Kyutai TTS model via ``moshi``; if that fails for any
    reason (package missing, download error, ...) it degrades to a simple
    synthetic-tone fallback so the endpoint always returns playable audio.
    """

    def __init__(self, path=""):
        """Initialize the handler, loading the model if possible.

        Args:
            path: Model path provided by the HF Endpoints runtime (unused here;
                the model is fetched by repo id instead).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🔧 Initialisation sur {self.device}")

        # Output sample rate (Hz) used for every WAV this handler emits.
        self.sample_rate = 24000
        self.model_loaded = False

        try:
            # Attempt to load the real Kyutai TTS model.
            from moshi.models import loaders
            print("📥 Tentative de chargement avec moshi...")
            # NOTE(review): assumes moshi exposes get_pretrained_lm_model with
            # these keyword args — confirm against the installed moshi version.
            self.lm_model = loaders.get_pretrained_lm_model(
                device=self.device,
                repo_id="kyutai/tts-1.6b-en_fr"
            )
            self.model_loaded = True
            print("✅ Modèle Kyutai chargé avec succès!")
        except Exception as e:
            # Deliberate broad catch: any load failure switches to fallback mode
            # rather than taking the endpoint down.
            print(f"⚠️ Impossible de charger Kyutai TTS: {e}")
            print("🔄 Mode fallback activé - génération audio basique")
            self.model_loaded = False

    def _detect_language(self, text: str) -> str:
        """Heuristic language detection: 'fr' if any French accented character
        appears in the text, otherwise 'en'."""
        fr_chars = set("àâäéèêëïîôùûçœ")
        lowered = text.lower()  # hoisted: the original lowered the text once per accent char
        return "fr" if any(c in lowered for c in fr_chars) else "en"

    def _fallback_audio(self, text: str) -> np.ndarray:
        """Generate a simple synthetic, vaguely voice-like tone sized to the text.

        Used when the real model is unavailable. Duration is ~60 ms per
        character, capped at 10 s.
        """
        duration = min(len(text) * 0.06, 10.0)
        samples = int(self.sample_rate * duration)
        t = np.linspace(0, duration, samples)

        # Two slowly-modulated carrier frequencies to loosely mimic speech formants.
        f1 = 200 + 50 * np.sin(2 * np.pi * 3 * t)
        f2 = 400 + 100 * np.sin(2 * np.pi * 2 * t)

        # Combine a few harmonics.
        audio = 0.3 * np.sin(2 * np.pi * f1 * t)
        audio += 0.2 * np.sin(2 * np.pi * f2 * t)
        audio += 0.1 * np.sin(2 * np.pi * 800 * t)

        # Decaying envelope so the output does not sound like a flat beep.
        audio *= np.exp(-t / duration * 2)
        return audio

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one TTS request.

        Args:
            data: Request payload. ``data["inputs"]`` is the text to synthesize;
                ``data["parameters"]["language"]`` may be "fr", "en" or "auto"
                (default "auto" → heuristic detection).

        Returns:
            Dict with base64-encoded WAV ``audio``, ``sampling_rate``,
            ``duration`` (seconds), ``model_loaded``, ``language`` and — only
            when synthesis failed — an ``error`` message (with 0.5 s of silence
            as the audio).

        Raises:
            ValueError: If ``inputs`` is missing or empty.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("Le paramètre 'inputs' est requis")

        params = data.get("parameters", {})
        language = params.get("language", "auto")
        if language == "auto":
            language = self._detect_language(text)

        try:
            if self.model_loaded:
                # Real model path.
                print(f"🎤 Synthèse Kyutai TTS: {len(text)} caractères en {language}")
                with torch.no_grad():
                    audio_tensor = self.lm_model.synthesize(
                        text=text,
                        language=language,
                        speaker_id=0,
                        speed=1.0
                    )
                audio_np = audio_tensor.cpu().numpy()
            else:
                # Fallback path: placeholder audio.
                print(f"🎵 Mode fallback: génération audio simple pour {len(text)} caractères")
                audio_np = self._fallback_audio(text)

            # Normalize to a 0.8 peak (both paths) so the int16 WAV encoder
            # never clips; the model branch previously went unnormalized.
            peak = np.max(np.abs(audio_np))
            if peak > 0:
                audio_np = audio_np / peak * 0.8

            audio_bytes = self.numpy_to_wav(audio_np, self.sample_rate)
            audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

            return {
                "audio": audio_base64,
                "sampling_rate": self.sample_rate,
                "duration": len(audio_np) / self.sample_rate,
                "model_loaded": self.model_loaded,
                "language": language
            }
        except Exception as e:
            # Best-effort error path: return half a second of silence plus the
            # error message instead of crashing the endpoint.
            print(f"❌ Erreur lors de la synthèse: {str(e)}")
            silence = np.zeros(int(self.sample_rate * 0.5))
            audio_bytes = self.numpy_to_wav(silence, self.sample_rate)
            return {
                "audio": base64.b64encode(audio_bytes).decode('utf-8'),
                "sampling_rate": self.sample_rate,
                "duration": 0.5,
                "error": str(e),
                "model_loaded": self.model_loaded,
                # Consistency fix: the success path always returns this key.
                "language": language
            }

    def numpy_to_wav(self, audio_np, sample_rate):
        """Serialize a mono float array (expected range [-1, 1]) to 16-bit PCM WAV bytes.

        Samples outside [-1, 1] are clipped first; the original code let them
        wrap around when cast to int16, producing loud clicks.
        """
        import struct

        # Ensure mono 1-D data.
        if audio_np.ndim > 1:
            audio_np = audio_np.flatten()

        # Clip, then scale to signed 16-bit PCM. Without the clip, values
        # above 1.0 overflow and wrap in the int16 cast.
        audio_int16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16)

        num_samples = len(audio_int16)
        num_channels = 1
        bits_per_sample = 16
        byte_rate = sample_rate * num_channels * bits_per_sample // 8
        block_align = num_channels * bits_per_sample // 8

        # Standard 44-byte RIFF/WAVE header for PCM data.
        wav_header = struct.pack(
            '<4sI4s4sIHHIIHH4sI',
            b'RIFF',
            36 + num_samples * 2,  # ChunkSize
            b'WAVE',
            b'fmt ',
            16,  # Subchunk1Size (PCM)
            1,   # AudioFormat (PCM)
            num_channels,
            sample_rate,
            byte_rate,
            block_align,
            bits_per_sample,
            b'data',
            num_samples * 2  # Subchunk2Size
        )

        # Header followed by little-endian PCM payload.
        return wav_header + audio_int16.tobytes()