# Hugging Face Space — voice assistant demo (Whisper ASR → Llama LLM → SpeechT5 TTS)
import os

import numpy as np
import torch
import gradio as gr
from datasets import load_dataset
from huggingface_hub import InferenceClient
from transformers import (
    SpeechT5ForTextToSpeech,
    SpeechT5HifiGan,
    SpeechT5Processor,
    pipeline,
)

# API token — must be defined in the Space's Secrets.
HF_TOKEN = os.getenv("HF_TOKEN")

# Pick the best available hardware (GPU if present, else CPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device utilisé : {device}")

# --- 1. Speech recognition (ASR) ---
# Whisper base (English) keeps transcription fast on CPU / light GPU.
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",
    device=device,
)

# --- 2. LLM client ---
client = InferenceClient(token=HF_TOKEN)

# --- 3. Text-to-speech (TTS) ---
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Speaker embedding (voice identity) fed to SpeechT5 at generation time.
embeddings_dataset = load_dataset(
    "Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True
)
speaker_embeddings = (
    torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
)
def transcribe(audio_path):
    """Turn a recorded audio file into text.

    Args:
        audio_path: Path to the audio file handed over by Gradio, or None.

    Returns:
        The transcribed text, or "" when no audio was provided.
    """
    if audio_path is None:
        return ""
    # The Whisper pipeline accepts file paths directly.
    result = transcriber(audio_path)
    return result["text"]
def query_llm(text):
    """Send the user's text to the chat LLM and return its reply.

    Any failure (network, auth, model error) is converted into a
    human-readable error string so the voice pipeline keeps running.
    """
    if not text:
        return "Je n'ai rien entendu."
    chat = [
        {"role": "system", "content": "You are a helpful vocal assistant. Keep your answers short and concise suitable for speech synthesis."},
        {"role": "user", "content": text},
    ]
    try:
        completion = client.chat.completions.create(
            model="meta-llama/Meta-Llama-3.1-8B-Instruct",
            messages=chat,
            max_tokens=150,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Surface the problem in the UI instead of crashing the demo.
        return f"Erreur LLM: {str(e)}"
def synthesise(text):
    """Convert text into speech with SpeechT5.

    Args:
        text: The sentence to vocalize.

    Returns:
        A (sample_rate, waveform) tuple consumable by gr.Audio, or None
        when there is nothing to say.
    """
    if not text:
        return None
    encoded = processor(text=text, return_tensors="pt")
    # SpeechT5 cannot handle arbitrarily long inputs: truncate the raw
    # text and re-tokenize when the token count is excessive.
    if encoded["input_ids"].shape[1] > 600:
        text = text[:500] + "..."
        encoded = processor(text=text, return_tensors="pt")
    token_ids = encoded["input_ids"].to(device)
    with torch.no_grad():
        waveform = model.generate_speech(
            token_ids, speaker_embeddings, vocoder=vocoder
        )
    # SpeechT5's HiFi-GAN vocoder emits 16 kHz audio.
    return (16000, waveform.cpu().numpy())
def process_pipeline(audio_path):
    """Full round trip driven by Gradio: audio in → text → LLM → audio out."""
    if audio_path is None:
        return "Aucun audio détecté", "...", None

    # 1. Speech -> text
    user_text = transcribe(audio_path)
    print(f"User: {user_text}")

    # 2. Text -> LLM reply
    ai_response = query_llm(user_text)
    print(f"AI: {ai_response}")

    # 3. Reply -> speech
    audio_result = synthesise(ai_response)

    return user_text, ai_response, audio_result
# --- Gradio interface ---
with gr.Blocks(title="Assistant Vocal AI") as demo:
    gr.Markdown("## 🎙️ Assistant Vocal Llama & Whisper")
    gr.Markdown("Parlez dans le micro, l'IA va transcrire, réfléchir et vous répondre oralement.")

    with gr.Row():
        # Left column: capture + trigger.
        with gr.Column():
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Votre voix")
            submit_btn = gr.Button("Envoyer", variant="primary")
        # Right column: transcription, textual answer, spoken answer.
        with gr.Column():
            transcription_box = gr.Textbox(label="Ce que j'ai entendu")
            response_box = gr.Textbox(label="Réponse textuelle")
            audio_output = gr.Audio(label="Réponse vocale", autoplay=True)

    submit_btn.click(
        fn=process_pipeline,
        inputs=[audio_input],
        outputs=[transcription_box, response_box, audio_output],
    )

if __name__ == "__main__":
    demo.launch()