# Voice_Assistant / app.py
import torch
from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from huggingface_hub import InferenceClient
from datasets import load_dataset
import gradio as gr
import os
import numpy as np
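# Runtime dependencies implied by the imports above (an assumption for a
# requirements.txt, not pinned versions): torch, transformers, huggingface_hub,
# datasets, gradio, numpy, plus sentencepiece (used by the SpeechT5 tokenizer)
# and soundfile (used by Gradio for audio I/O).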
# Retrieve the token (make sure it is defined in the Space's Secrets)
HF_TOKEN = os.getenv("HF_TOKEN")

# Hardware detection (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device in use: {device}")
# --- 1. Transcription model (ASR) ---
# whisper-base.en is a small English-only checkpoint, fast enough on CPU or a light GPU
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",
    device=device
)
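# Note: whisper-base.en only understands English. A hedged alternative, if you
# want multilingual input, is the same-size multilingual checkpoint:
#
#     transcriber = pipeline(
#         "automatic-speech-recognition",
#         model="openai/whisper-base",
#         device=device,
#     )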
# --- 2. LLM client ---
client = InferenceClient(
    token=HF_TOKEN
)
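# Note: no model is fixed here; the target model is passed per request in
# chat.completions.create below. If HF_TOKEN is unset, the client falls back to
# locally stored credentials or anonymous requests, which are rate limited, and
# gated models such as Llama 3.1 require authentication.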
# --- 3. Text-to-speech (TTS) ---
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Load the speaker embedding (voice)
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True)
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
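# Note: index 7306 is the x-vector used throughout the Hugging Face SpeechT5
# examples (a US English female voice). Each row of the dataset is a 512-dim
# speaker embedding; picking a different row changes the synthesized voice.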
def transcribe(audio_path):
    """Convert audio (file path) to text."""
    if audio_path is None:
        return ""
    # Whisper handles the file paths sent by Gradio directly
    text = transcriber(audio_path)["text"]
    return text
def query_llm(text):
    """Send the text to the LLM."""
    if not text:
        return "I didn't hear anything."
    try:
        messages = [
            {"role": "system", "content": "You are a helpful vocal assistant. Keep your answers short and concise, suitable for speech synthesis."},
            {"role": "user", "content": text}
        ]
        completion = client.chat.completions.create(
            model="meta-llama/Meta-Llama-3.1-8B-Instruct",
            messages=messages,
            max_tokens=150
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"LLM error: {str(e)}"
def synthesise(text):
    """Convert text to audio."""
    if not text:
        return None
    inputs = processor(text=text, return_tensors="pt")
    # Handle text length (SpeechT5 has an input limit)
    if inputs["input_ids"].shape[1] > 600:
        text = text[:500] + "..."  # Truncate if too long
        inputs = processor(text=text, return_tensors="pt")
    input_ids = inputs["input_ids"].to(device)
    with torch.no_grad():
        speech = model.generate_speech(
            input_ids,
            speaker_embeddings,
            vocoder=vocoder
        )
    # Return (sampling rate, audio array)
    return (16000, speech.cpu().numpy())
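# Note: SpeechT5 with the HiFi-GAN vocoder generates mono audio at a fixed
# 16 kHz, so the sampling rate returned above must stay 16000 for Gradio to
# play the clip at the correct speed.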
def process_pipeline(audio_path):
    """Main function called by Gradio."""
    if audio_path is None:
        return "No audio detected", "...", None
    # 1. Transcription
    user_text = transcribe(audio_path)
    print(f"User: {user_text}")
    # 2. Reasoning (LLM)
    ai_response = query_llm(user_text)
    print(f"AI: {ai_response}")
    # 3. Synthesis (TTS)
    audio_result = synthesise(ai_response)
    return user_text, ai_response, audio_result
# --- Gradio interface ---
with gr.Blocks(title="AI Voice Assistant") as demo:
    gr.Markdown("## 🎙️ Llama & Whisper Voice Assistant")
    gr.Markdown("Speak into the microphone; the AI will transcribe, think, and answer you out loud.")
    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Your voice")
            submit_btn = gr.Button("Send", variant="primary")
        with gr.Column():
            transcription_box = gr.Textbox(label="What I heard")
            response_box = gr.Textbox(label="Text response")
            audio_output = gr.Audio(label="Voice response", autoplay=True)
    submit_btn.click(
        fn=process_pipeline,
        inputs=[audio_input],
        outputs=[transcription_box, response_box, audio_output]
    )
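# Optional local smoke test (a sketch, assuming a short English WAV file named
# "sample.wav" sits next to this script; kept commented out so the app still
# launches normally):
#
#     user_text, ai_response, audio = process_pipeline("sample.wav")
#     print(user_text, ai_response, audio[0] if audio else None)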
if __name__ == "__main__":
demo.launch()