|
|
""" |
|
|
Handler final pour Kyutai TTS - Compatible HF Endpoints |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import base64 |
|
|
import io |
|
|
import numpy as np |
|
|
from typing import Dict, Any |
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for Kyutai TTS.

    On construction it tries to load the Kyutai model through ``moshi``. If that
    fails for any reason, the handler switches to a fallback mode that
    synthesizes a simple placeholder tone, so every request still returns a
    valid base64-encoded WAV file.
    """

    # Accented characters used by the "auto" language heuristic: if any of them
    # appears in the (lower-cased) input text, French is selected, otherwise
    # English is assumed.
    _FRENCH_CHARS = frozenset("àâäéèêëïîôùûçœ")

    def __init__(self, path=""):
        """Initialise the handler, falling back to a basic tone generator.

        Args:
            path: Model directory passed by the endpoint runtime. It is not used
                here; the model is requested from the hard-coded
                ``kyutai/tts-1.6b-en_fr`` repository instead.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🔧 Initialisation sur {self.device}")

        self.sample_rate = 24000
        self.model_loaded = False
        # Defined up front so the attribute exists even in fallback mode.
        self.lm_model = None

        try:
            # Imported lazily so a missing ``moshi`` package only disables the
            # model path instead of breaking module import.
            from moshi.models import loaders

            print("📥 Tentative de chargement avec moshi...")
            # NOTE(review): confirm that ``loaders.get_pretrained_lm_model``
            # exists in the installed moshi version and that the object it
            # returns exposes ``synthesize`` (used in ``_synthesize_with_model``);
            # otherwise every request ends in the silent-error path.
            self.lm_model = loaders.get_pretrained_lm_model(
                device=self.device,
                repo_id="kyutai/tts-1.6b-en_fr",
            )
            self.model_loaded = True
            print("✅ Modèle Kyutai chargé avec succès!")
        except Exception as e:
            # Deliberate best-effort: any load failure degrades to fallback mode.
            print(f"⚠️ Impossible de charger Kyutai TTS: {e}")
            print("🔄 Mode fallback activé - génération audio basique")
            self.lm_model = None
            self.model_loaded = False

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one TTS request.

        Args:
            data: Request payload. ``data["inputs"]`` is the text to speak. The
                optional ``data["parameters"]["language"]`` selects the language
                (``"auto"`` by default, which triggers a simple accent-based
                French/English detection).

        Returns:
            A dict with the base64-encoded WAV (``"audio"``), ``"sampling_rate"``,
            ``"duration"`` in seconds, ``"model_loaded"`` and ``"language"``.
            If synthesis raises, 0.5 s of silence is returned instead and an
            ``"error"`` key carries the exception message.

        Raises:
            ValueError: If ``inputs`` is missing or empty.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("Le paramètre 'inputs' est requis")

        # ``parameters`` may be absent or explicitly null in the JSON payload.
        params = data.get("parameters") or {}
        language = params.get("language", "auto")
        if language == "auto":
            language = self._detect_language(text)

        try:
            if self.model_loaded:
                audio_np = self._synthesize_with_model(text, language)
            else:
                audio_np = self._synthesize_fallback(text)

            return {
                "audio": self._encode_wav_base64(audio_np),
                "sampling_rate": self.sample_rate,
                "duration": len(audio_np) / self.sample_rate,
                "model_loaded": self.model_loaded,
                "language": language,
            }
        except Exception as e:
            # Best-effort: always answer with valid audio so clients don't break.
            print(f"❌ Erreur lors de la synthèse: {str(e)}")
            silence = np.zeros(int(self.sample_rate * 0.5))
            return {
                "audio": self._encode_wav_base64(silence),
                "sampling_rate": self.sample_rate,
                "duration": len(silence) / self.sample_rate,
                "error": str(e),
                "model_loaded": self.model_loaded,
                "language": language,
            }

    def _detect_language(self, text: str) -> str:
        """Return ``"fr"`` if *text* contains a French accented character, else ``"en"``."""
        return "en" if self._FRENCH_CHARS.isdisjoint(text.lower()) else "fr"

    def _synthesize_with_model(self, text: str, language: str) -> np.ndarray:
        """Run the loaded Kyutai model and return a 1-D float waveform."""
        print(f"🎤 Synthèse Kyutai TTS: {len(text)} caractères en {language}")
        with torch.no_grad():
            audio_tensor = self.lm_model.synthesize(
                text=text,
                language=language,
                speaker_id=0,
                speed=1.0,
            )
        # Flatten any batch/channel dimensions so ``len`` counts samples and the
        # reported duration is right; ``.float()`` makes half-precision tensors
        # (e.g. bfloat16, which NumPy cannot represent) convertible.
        return audio_tensor.detach().cpu().float().numpy().reshape(-1)

    def _synthesize_fallback(self, text: str) -> np.ndarray:
        """Generate a placeholder tone whose length scales with *text* (capped at 10 s)."""
        print(f"🎵 Mode fallback: génération audio simple pour {len(text)} caractères")
        duration = min(len(text) * 0.06, 10.0)
        samples = int(self.sample_rate * duration)
        # Sample instants k / sample_rate. Unlike ``np.linspace(0, duration, n)``
        # the spacing is exactly one sample period.
        t = np.arange(samples) / self.sample_rate

        # Two slowly wobbling tones plus a fixed 800 Hz partial.
        f1 = 200 + 50 * np.sin(2 * np.pi * 3 * t)
        f2 = 400 + 100 * np.sin(2 * np.pi * 2 * t)
        audio_np = 0.3 * np.sin(2 * np.pi * f1 * t)
        audio_np += 0.2 * np.sin(2 * np.pi * f2 * t)
        audio_np += 0.1 * np.sin(2 * np.pi * 800 * t)

        # Exponential fade-out over the clip, then peak-normalise to 0.8.
        audio_np *= np.exp(-t / duration * 2)
        peak = np.max(np.abs(audio_np), initial=0.0)
        if peak > 0:
            audio_np = audio_np / peak * 0.8
        return audio_np

    def _encode_wav_base64(self, audio_np: np.ndarray) -> str:
        """Encode *audio_np* as a mono 16-bit WAV and return it as a base64 string."""
        wav_bytes = self.numpy_to_wav(audio_np, self.sample_rate)
        return base64.b64encode(wav_bytes).decode("utf-8")

    def numpy_to_wav(self, audio_np, sample_rate):
        """Convert a float waveform to the bytes of a mono 16-bit PCM WAV file.

        Args:
            audio_np: Samples nominally in ``[-1.0, 1.0]``. Multi-dimensional
                input is flattened in C order. Out-of-range values are clipped
                and NaN becomes silence, so the int16 conversion cannot wrap
                around.
            sample_rate: Sampling rate in Hz written to the header.

        Returns:
            The complete WAV file: a 44-byte RIFF header followed by the
            little-endian PCM data.
        """
        import struct

        samples = np.asarray(audio_np, dtype=np.float64).reshape(-1)
        # Clip before scaling: casting out-of-range floats to int16 wraps around
        # (or is undefined), turning loud peaks into harsh clicks.
        samples = np.clip(np.nan_to_num(samples, nan=0.0), -1.0, 1.0)
        # WAV PCM data is little-endian regardless of host byte order.
        audio_int16 = (samples * 32767).astype("<i2")

        num_channels = 1
        bits_per_sample = 16
        block_align = num_channels * bits_per_sample // 8
        byte_rate = sample_rate * block_align
        data_size = len(audio_int16) * block_align

        wav_header = struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF",
            36 + data_size,  # RIFF chunk size: everything after this field
            b"WAVE",
            b"fmt ",
            16,  # fmt chunk size for PCM
            1,  # audio format 1 = uncompressed PCM
            num_channels,
            sample_rate,
            byte_rate,
            block_align,
            bits_per_sample,
            b"data",
            data_size,
        )
        return wav_header + audio_int16.tobytes()