| """ |
| Handler simplifié pour Kyutai TTS - Version minimaliste |
| """ |
|
|
| import torch |
| import base64 |
| import io |
| import numpy as np |
| from typing import Dict, Any |
|
|
class EndpointHandler:
    # Accented characters common in French; used by __call__ for naive
    # language auto-detection.
    _FRENCH_CHARS = frozenset("àâäéèêëïîôùûçœ")

    def __init__(self, path: str = ""):
        """
        Initialize the handler.

        Tries to load the Kyutai TTS model via the `moshi` package first;
        on any failure, falls back to loading the same checkpoint with
        `transformers`. Sets `self.use_moshi` so `__call__` knows which
        synthesis path is available.

        Args:
            path: model path supplied by the inference runtime (unused here;
                the repo id is hard-coded).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🔧 Initialisation sur {self.device}")

        try:
            from moshi.models import loaders
            print("📥 Chargement du modèle avec moshi...")
            # NOTE(review): assumes this loader entry point exists in the
            # installed moshi version — confirm against its API.
            self.lm_model = loaders.get_pretrained_lm_model(
                device=self.device,
                repo_id="kyutai/tts-1.6b-en_fr",
            )
            self.use_moshi = True
            print("✅ Modèle chargé avec moshi!")
        except Exception as e:
            print(f"⚠️ Erreur moshi: {e}")
            print("📥 Chargement alternatif du modèle...")

            from transformers import AutoModelForCausalLM, AutoTokenizer
            # fp16 only makes sense on GPU; stay in fp32 on CPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                "kyutai/tts-1.6b-en_fr",
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
                device_map="auto",
            )
            self.tokenizer = AutoTokenizer.from_pretrained("kyutai/tts-1.6b-en_fr")
            self.use_moshi = False
            print("✅ Modèle chargé avec transformers!")

        # Sample rate (Hz) used for all generated and fallback audio.
        self.sample_rate = 24000

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle a TTS request.

        Args:
            data: request payload. Expects:
                - "inputs" (str, required): text to synthesize.
                - "parameters" (dict, optional): may contain "language"
                  ("fr", "en", or "auto"; default "auto").

        Returns:
            dict with:
                - "audio": base64-encoded mono 16-bit WAV,
                - "sampling_rate": output sample rate in Hz,
                - "duration": audio length in seconds,
                - "error": present only when synthesis failed (one second of
                  silence is returned instead of raising).

        Raises:
            ValueError: if "inputs" is missing or empty.
        """
        text = data.get("inputs", "")
        if not text:
            raise ValueError("Le paramètre 'inputs' est requis")

        params = data.get("parameters", {})
        language = params.get("language", "auto")

        # Cheap auto-detection: any accented character common in French
        # flags the whole text as French. Lowercase once, not per character.
        if language == "auto":
            lowered = text.lower()
            has_french = any(c in lowered for c in self._FRENCH_CHARS)
            language = "fr" if has_french else "en"

        try:
            print(f"🎤 Synthèse TTS: {len(text)} caractères en {language}")

            if self.use_moshi:
                with torch.no_grad():
                    # NOTE(review): assumes the moshi LM exposes a
                    # synthesize() API returning a waveform tensor — confirm
                    # against the installed moshi version.
                    audio_tensor = self.lm_model.synthesize(
                        text=text,
                        language=language,
                        speaker_id=0,
                        speed=1.0,
                    )
                    audio_np = audio_tensor.cpu().numpy()
            else:
                # Fallback: the transformers checkpoint is not wired up for
                # synthesis here, so emit a 440 Hz test tone whose length
                # scales with the input (~50 ms per character).
                print("⚠️ Mode fallback: audio de test")
                duration = len(text) * 0.05
                t = np.linspace(0, duration, int(self.sample_rate * duration))
                audio_np = 0.5 * np.sin(2 * np.pi * 440 * t)

            # Peak-normalize; the epsilon guards against division by zero
            # when the output is all-silent.
            audio_np = audio_np / (np.max(np.abs(audio_np)) + 1e-8)

            audio_bytes = self.numpy_to_wav(audio_np, self.sample_rate)
            audio_base64 = base64.b64encode(audio_bytes).decode('utf-8')

            return {
                "audio": audio_base64,
                "sampling_rate": self.sample_rate,
                "duration": len(audio_np) / self.sample_rate,
            }

        except Exception as e:
            # Best-effort degradation: return one second of silence plus the
            # error message instead of failing the whole request.
            print(f"❌ Erreur TTS: {str(e)}")
            silence = np.zeros(self.sample_rate)
            audio_bytes = self.numpy_to_wav(silence, self.sample_rate)
            return {
                "audio": base64.b64encode(audio_bytes).decode('utf-8'),
                "sampling_rate": self.sample_rate,
                "duration": 1.0,
                "error": str(e),
            }

    def numpy_to_wav(self, audio_np: np.ndarray, sample_rate: int) -> bytes:
        """
        Encode a float waveform as mono 16-bit PCM WAV bytes.

        Args:
            audio_np: 1-D float array, expected in [-1.0, 1.0]; out-of-range
                samples are clipped.
            sample_rate: sample rate in Hz written to the WAV header.

        Returns:
            Complete WAV file contents (header + frames).
        """
        import wave

        buffer = io.BytesIO()
        with wave.open(buffer, 'wb') as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 16-bit samples
            wav_file.setframerate(sample_rate)
            # Clip before scaling so out-of-range floats saturate instead of
            # wrapping around when cast to int16.
            audio_int16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16)
            wav_file.writeframes(audio_int16.tobytes())

        return buffer.getvalue()