# Hugging Face Spaces — running on ZeroGPU
import logging
import os
from typing import Optional

import numpy as np
import soundfile as sf
import torch
from huggingface_hub import hf_hub_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
# --- CONSTANTS ---
# Hugging Face Hub repository that holds the fine-tuned Wolof XTTS v2 assets.
REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference"
# Local directory where the downloaded model files are cached.
LOCAL_DIR = "./models"
class WolofXTTSInference:
    """Wolof text-to-speech wrapper around a fine-tuned XTTS v2 model.

    On construction, downloads the checkpoint, configuration, tokenizer
    vocabulary and a default speaker reference wav from the Hugging Face Hub,
    then loads the model onto GPU when available.  Synthesis is exposed via
    :meth:`generate_audio` and :meth:`generate_audio_from_config`.
    """

    def __init__(self, repo_id: str = REPO_ID, local_dir: str = LOCAL_DIR):
        """Download the model assets into ``local_dir`` and load the model.

        Args:
            repo_id: Hugging Face Hub repository to download from.
            local_dir: Local cache directory for the downloaded files.

        Raises:
            Exception: re-raised from ``hf_hub_download`` or model loading.
        """
        # NOTE(review): basicConfig mutates the *root* logger as a side effect
        # of constructing this object; kept as-is for backward compatibility.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

        # Create the local cache directory if it does not already exist.
        os.makedirs(local_dir, exist_ok=True)

        # Download the required files.
        try:
            # Create the sub-directories the Hub filenames map into.
            os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True)
            os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True)

            # Model checkpoint.
            self.model_path = hf_hub_download(
                repo_id=repo_id,
                filename="Anta_GPT_XTTS_Wo/best_model_89250.pth",
                local_dir=local_dir
            )
            # Model configuration.
            self.config_path = hf_hub_download(
                repo_id=repo_id,
                filename="Anta_GPT_XTTS_Wo/config.json",
                local_dir=local_dir
            )
            # Tokenizer vocabulary.
            self.vocab_path = hf_hub_download(
                repo_id=repo_id,
                filename="XTTS_v2.0_original_model_files/vocab.json",
                local_dir=local_dir
            )
            # Default speaker reference audio.
            self.reference_audio = hf_hub_download(
                repo_id=repo_id,
                filename="anta_sample.wav",
                local_dir=local_dir
            )
        except Exception as e:
            self.logger.error(f"Erreur lors du téléchargement des fichiers : {e}")
            raise

        # Device selection: first CUDA GPU when available, otherwise CPU.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"

        # Model initialization.
        self.model = self._load_model()

    def _load_model(self):
        """Load the XTTS model from the downloaded files onto ``self.device``.

        Returns:
            The loaded ``Xtts`` model in eval mode.

        Raises:
            Exception: re-raised from config parsing or checkpoint loading.
        """
        try:
            self.logger.info("Chargement du modèle XTTS...")
            config = XttsConfig()
            config.load_json(self.config_path)
            model = Xtts.init_from_config(config)
            # Load the checkpoint weights; deepspeed is not used here.
            model.load_checkpoint(
                config,
                checkpoint_path=self.model_path,
                vocab_path=self.vocab_path,
                use_deepspeed=False
            )
            model.to(self.device)
            model.eval()  # inference only: disable dropout / training behavior
            self.logger.info("Modèle chargé avec succès!")
            return model
        except Exception as e:
            self.logger.error(f"Erreur lors du chargement du modèle : {e}")
            raise

    def generate_audio(
        self,
        text: str,
        reference_audio: Optional[str] = None,
        speed: float = 1.06,
        language: str = "wo",
        output_path: Optional[str] = None
    ) -> tuple[np.ndarray, int]:
        """Synthesize speech for ``text``.

        Args:
            text: Text to convert to audio; must contain non-whitespace.
            reference_audio: Path to a speaker reference wav. Falls back to
                the bundled default sample when None.
            speed: Playback speed factor.
            language: Language code passed to the model.
            output_path: When given, the generated wav is also written there.

        Returns:
            tuple[np.ndarray, int]: (audio_array, sample_rate)

        Raises:
            ValueError: if ``text`` is empty or whitespace-only.
        """
        # Fix: also reject whitespace-only input, which previously slipped
        # through to inference.
        if not text or not text.strip():
            raise ValueError("Le texte ne peut pas être vide.")
        try:
            # Use the provided reference audio, or the default one.
            ref_audio = reference_audio or self.reference_audio

            # Speaker-conditioning embeddings from the reference audio.
            gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
                audio_path=[ref_audio],
                gpt_cond_len=self.model.config.gpt_cond_len,
                max_ref_length=self.model.config.max_ref_len,
                sound_norm_refs=self.model.config.sound_norm_refs
            )

            # Audio generation.  Text is lower-cased — presumably the model
            # was trained on lower-cased text; TODO confirm.
            result = self.model.inference(
                text=text.lower(),
                gpt_cond_latent=gpt_cond_latent,
                speaker_embedding=speaker_embedding,
                do_sample=False,
                speed=speed,
                language=language,
                enable_text_splitting=True
            )

            # Sample rate comes from the model's audio config.
            sample_rate = self.model.config.audio.sample_rate

            # Optional save to disk.
            if output_path:
                sf.write(output_path, result["wav"], sample_rate)
                self.logger.info(f"Audio sauvegardé dans {output_path}")

            return result["wav"], sample_rate
        except Exception as e:
            self.logger.error(f"Erreur lors de la génération de l'audio : {e}")
            raise

    def generate_audio_from_config(self, text: str, config: dict, output_path: Optional[str] = None) -> tuple[np.ndarray, int]:
        """Synthesize speech using options taken from a configuration dict.

        Args:
            text: Text to convert to audio.
            config: Options dict; recognized keys are 'speed', 'language'
                and 'reference_audio'.  Missing keys fall back to
                :meth:`generate_audio`'s defaults.
            output_path: When given, the generated wav is also written there.

        Returns:
            tuple[np.ndarray, int]: (audio_array, sample_rate)
        """
        # Fix: defaults were duplicated here (1.06 / "wo"); keeping them only
        # on generate_audio avoids the two sets of defaults drifting apart.
        return self.generate_audio(
            text=text,
            reference_audio=config.get('reference_audio'),
            speed=config.get('speed', 1.06),
            language=config.get('language', "wo"),
            output_path=output_path,
        )
# Example usage
if __name__ == "__main__":
    tts = WolofXTTSInference()

    # Sample sentence to synthesize.
    sample_text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!"

    # Plain call with the default settings.
    audio, sr = tts.generate_audio(sample_text, output_path="generated_audio.wav")

    # Same text again, this time driven by a configuration dictionary.
    generation_options = {
        "speed": 1.2,
        "language": "wo",
    }
    audio, sr = tts.generate_audio_from_config(
        text=sample_text,
        config=generation_options,
        output_path="generated_audio_config.wav",
    )