File size: 4,309 Bytes
31caba0
b70f4c1
 
31caba0
 
b70f4c1
 
31caba0
b70f4c1
31caba0
 
b70f4c1
 
 
31caba0
b70f4c1
 
31caba0
b70f4c1
 
 
31caba0
 
3b31dcb
31caba0
3b31dcb
31caba0
 
b70f4c1
 
 
 
 
 
8f1fe8d
b70f4c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31caba0
b70f4c1
 
 
 
 
31caba0
17036d1
b70f4c1
17036d1
31caba0
 
 
b70f4c1
31caba0
b70f4c1
 
 
 
 
 
 
 
 
 
 
31caba0
b70f4c1
31caba0
b70f4c1
3420952
b70f4c1
 
3420952
 
b70f4c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31caba0
 
 
b70f4c1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from huggingface_hub import InferenceClient
from datasets import load_dataset
import gradio as gr
import os
import numpy as np

# Récupération du token (Assure-toi de l'avoir défini dans les Secrets du Space)
HF_TOKEN = os.getenv("HF_TOKEN")

# Détection du hardware (GPU ou CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device utilisé : {device}")

# --- 1. Modèles de Transcription (ASR) ---
# Utilisation de distil-whisper pour plus de rapidité sur CPU/GPU léger
transcriber = pipeline(
    "automatic-speech-recognition", 
    model="openai/whisper-base.en", 
    device=device
)

# --- 2. Client LLM ---
client = InferenceClient(
    token=HF_TOKEN
)

# --- 3. Synthèse Vocale (TTS) ---
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Chargement du speaker embedding (voix)
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation", trust_remote_code=True)
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)

def transcribe(audio_path):
    """Convertit l'audio (chemin de fichier) en texte."""
    if audio_path is None:
        return ""
    
    # Whisper gère directement les chemins de fichiers envoyés par Gradio
    text = transcriber(audio_path)["text"]
    return text

def query_llm(text):
    """Envoie le texte au LLM."""
    if not text:
        return "Je n'ai rien entendu."
    
    try:
        messages = [
            {"role": "system", "content": "You are a helpful vocal assistant. Keep your answers short and concise suitable for speech synthesis."},
            {"role": "user", "content": text}
        ]
        
        completion = client.chat.completions.create(
            model="meta-llama/Meta-Llama-3.1-8B-Instruct", 
            messages=messages,
            max_tokens=150
        )
        return completion.choices[0].message.content
    except Exception as e:
        return f"Erreur LLM: {str(e)}"

def synthesise(text):
    """Convertit le texte en audio."""
    if not text:
        return None
        
    inputs = processor(text=text, return_tensors="pt")
    
    # Gestion de la taille du texte (SpeechT5 a une limite)
    if inputs["input_ids"].shape[1] > 600:
        text = text[:500] + "..." # Tronquer si trop long
        inputs = processor(text=text, return_tensors="pt")

    input_ids = inputs["input_ids"].to(device)

    with torch.no_grad():
        speech = model.generate_speech(
            input_ids,
            speaker_embeddings,
            vocoder=vocoder
        )
    
    # Retourne (Sampling Rate, Audio Array)
    return (16000, speech.cpu().numpy())

def process_pipeline(audio_path):
    """Fonction principale appelée par Gradio"""
    if audio_path is None:
        return "Aucun audio détecté", "...", None

    # 1. Transcription
    user_text = transcribe(audio_path)
    print(f"User: {user_text}")

    # 2. Réflexion (LLM)
    ai_response = query_llm(user_text)
    print(f"AI: {ai_response}")

    # 3. Synthèse (TTS)
    audio_result = synthesise(ai_response)

    return user_text, ai_response, audio_result

# --- Interface Gradio ---
with gr.Blocks(title="Assistant Vocal AI") as demo:
    gr.Markdown("## 🎙️ Assistant Vocal Llama & Whisper")
    gr.Markdown("Parlez dans le micro, l'IA va transcrire, réfléchir et vous répondre oralement.")

    with gr.Row():
        with gr.Column():
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Votre voix")
            submit_btn = gr.Button("Envoyer", variant="primary")
        
        with gr.Column():
            transcription_box = gr.Textbox(label="Ce que j'ai entendu")
            response_box = gr.Textbox(label="Réponse textuelle")
            audio_output = gr.Audio(label="Réponse vocale", autoplay=True)

    submit_btn.click(
        fn=process_pipeline,
        inputs=[audio_input],
        outputs=[transcription_box, response_box, audio_output]
    )

if __name__ == "__main__":
    demo.launch()