# kyutai-tts-handler / handler.py
"""
Handler simplifié pour Kyutai TTS - Version minimaliste
"""
import torch
import base64
import io
import numpy as np
from typing import Dict, Any
class EndpointHandler:
    """Minimal inference-endpoint handler for Kyutai TTS.

    Loads the model through the ``moshi`` loaders when available, falling
    back to a plain ``transformers`` causal-LM load otherwise. Requests go
    through ``__call__`` and return base64-encoded 16-bit mono WAV audio.
    """

    def __init__(self, path=""):
        """Load the TTS model, preferring the moshi loader.

        Args:
            path: Model directory supplied by the endpoint runtime
                (unused here; the repo id is hard-coded).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🔧 Initialisation sur {self.device}")
        try:
            # Preferred path: load through the moshi loaders.
            # NOTE(review): assumes `loaders.get_pretrained_lm_model` exists in
            # the installed moshi version — TODO confirm; any failure falls
            # through to the transformers path below.
            from moshi.models import loaders
            print("📥 Chargement du modèle avec moshi...")
            self.lm_model = loaders.get_pretrained_lm_model(
                device=self.device,
                repo_id="kyutai/tts-1.6b-en_fr"
            )
            self.use_moshi = True
            print("✅ Modèle chargé avec moshi!")
        except Exception as e:
            print(f"⚠️ Erreur moshi: {e}")
            print("📥 Chargement alternatif du modèle...")
            # Fallback: load the checkpoint directly with transformers.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            self.model = AutoModelForCausalLM.from_pretrained(
                "kyutai/tts-1.6b-en_fr",
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto"
            )
            self.tokenizer = AutoTokenizer.from_pretrained("kyutai/tts-1.6b-en_fr")
            self.use_moshi = False
            print("✅ Modèle chargé avec transformers!")
        self.sample_rate = 24000

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one TTS request.

        Args:
            data: Request payload. Text is read from ``"inputs"``; an
                optional ``"parameters"`` dict may set ``"language"`` to
                ``"fr"``, ``"en"`` or ``"auto"`` (default).

        Returns:
            Dict with base64-encoded WAV ``"audio"``, ``"sampling_rate"``
            and ``"duration"`` in seconds. On synthesis failure, one
            second of silence plus an ``"error"`` message.

        Raises:
            ValueError: If ``"inputs"`` is missing or empty.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("Le paramètre 'inputs' est requis")
        params = data.get("parameters", {})
        language = params.get("language", "auto")
        # Crude language detection: any accented French character => "fr".
        if language == "auto":
            fr_chars = set("àâäéèêëïîôùûçœ")
            has_french = any(c in text.lower() for c in fr_chars)
            language = "fr" if has_french else "en"
        try:
            print(f"🎤 Synthèse TTS: {len(text)} caractères en {language}")
            if self.use_moshi:
                # Synthesis through moshi.
                with torch.no_grad():
                    audio_tensor = self.lm_model.synthesize(
                        text=text,
                        language=language,
                        speaker_id=0,
                        speed=1.0
                    )
                # assumes a 1-D waveform tensor — TODO confirm moshi output shape
                audio_np = audio_tensor.cpu().numpy()
            else:
                # Fallback: generate a placeholder test tone (no real TTS).
                print("⚠️ Mode fallback: audio de test")
                duration = len(text) * 0.05  # ~50 ms per character
                t = np.linspace(0, duration, int(self.sample_rate * duration))
                # Simple 440 Hz sine tone.
                audio_np = 0.5 * np.sin(2 * np.pi * 440 * t)
            # Normalize to [-1, 1]; epsilon guards against division by zero.
            audio_np = audio_np / (np.max(np.abs(audio_np)) + 1e-8)
            # Encode as a 16-bit mono WAV container.
            audio_bytes = self.numpy_to_wav(audio_np, self.sample_rate)
            audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')
            return {
                "audio": audio_base64,
                "sampling_rate": self.sample_rate,
                "duration": len(audio_np) / self.sample_rate
            }
        except Exception as e:
            print(f"❌ Erreur TTS: {str(e)}")
            # Best-effort response: return one second of silence with the
            # error attached instead of failing the whole request.
            silence = np.zeros(self.sample_rate)
            audio_bytes = self.numpy_to_wav(silence, self.sample_rate)
            return {
                "audio": base64.b64encode(audio_bytes).decode('utf-8'),
                "sampling_rate": self.sample_rate,
                "duration": 1.0,
                "error": str(e)
            }

    def numpy_to_wav(self, audio_np, sample_rate):
        """Convert a float numpy waveform to 16-bit mono WAV bytes.

        Args:
            audio_np: 1-D float array, nominally in [-1, 1].
            sample_rate: Output sample rate in Hz.

        Returns:
            Complete WAV file contents as ``bytes``.
        """
        import wave
        buffer = io.BytesIO()
        with wave.open(buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)  # mono
            wav_file.setsampwidth(2)  # 16-bit PCM
            wav_file.setframerate(sample_rate)
            # Clip before scaling: samples outside [-1, 1] would otherwise
            # wrap around when cast to int16 (e.g. 2.0 -> -2), producing
            # loud artifacts instead of clean saturation.
            clipped = np.clip(audio_np, -1.0, 1.0)
            wav_file.writeframes((clipped * 32767).astype(np.int16).tobytes())
        return buffer.getvalue()