quentinbch commited on
Commit
b70f4c1
·
verified ·
1 Parent(s): a4c158d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -117
app.py CHANGED
@@ -1,146 +1,133 @@
1
- from transformers import pipeline
2
  import torch
3
- from transformers.pipelines.audio_utils import ffmpeg_microphone_live
4
- from huggingface_hub import HfFolder, InferenceClient
5
- import requests
6
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
7
  from datasets import load_dataset
8
- import sounddevice as sd
9
- import sys
10
- import os
11
- from dotenv import load_dotenv
12
  import gradio as gr
13
- import warnings
 
14
 
15
- load_dotenv()
16
  HF_TOKEN = os.getenv("HF_TOKEN")
17
 
18
- warnings.filterwarnings("ignore",
19
- message="At least one mel filter has all zero values.*",
20
- category=UserWarning)
21
-
22
- device = "cuda:0" if torch.cuda.is_available() else "cpu"
23
- classifier = pipeline(
24
- "audio-classification",
25
- model="MIT/ast-finetuned-speech-commands-v2",
26
- device=device
27
- )
28
-
29
- def launch_fn(wake_word="marvin", prob_threshold=0.5, chunk_length_s=2.0, stream_chunk_s=0.25, debug=False):
30
- if wake_word not in classifier.model.config.label2id.keys():
31
- raise ValueError(
32
- f"Wake word {wake_word} not in set of valid class labels, pick a wake word in the set {classifier.model.config.label2id.keys()}."
33
- )
34
-
35
- sampling_rate = classifier.feature_extractor.sampling_rate
36
-
37
- mic = ffmpeg_microphone_live(
38
- sampling_rate=sampling_rate,
39
- chunk_length_s=chunk_length_s,
40
- stream_chunk_s=stream_chunk_s,
41
- )
42
-
43
- print("Listening for wake word...")
44
- for prediction in classifier(mic):
45
- prediction = prediction[0]
46
- if debug:
47
- print(prediction)
48
- if prediction["label"] == wake_word:
49
- if prediction["score"] > prob_threshold:
50
- return True
51
 
 
 
52
  transcriber = pipeline(
53
- "automatic-speech-recognition", model="openai/whisper-base.en", device=device
 
 
54
  )
55
 
56
- def transcribe(chunk_length_s=5.0, stream_chunk_s=1.0):
57
- sampling_rate = transcriber.feature_extractor.sampling_rate
58
-
59
- mic = ffmpeg_microphone_live(
60
- sampling_rate=sampling_rate,
61
- chunk_length_s=chunk_length_s,
62
- stream_chunk_s=stream_chunk_s,
63
- )
64
-
65
- print("Start speaking...")
66
- for item in transcriber(mic, generate_kwargs={"max_new_tokens": 128}):
67
- sys.stdout.write("\033[K")
68
- print(item["text"], end="\r")
69
- if not item["partial"][0]:
70
- break
71
-
72
- return item["text"]
73
-
74
-
75
-
76
  client = InferenceClient(
77
  provider="fireworks-ai",
78
  api_key=HF_TOKEN
79
  )
80
 
81
- def query(text, model_id="meta-llama/Llama-3.1-8B-Instruct"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  try:
 
 
 
 
 
 
83
  completion = client.chat.completions.create(
84
- model=model_id,
85
- messages=[{"role": "user", "content": text}]
 
86
  )
87
  return completion.choices[0].message.content
88
-
89
  except Exception as e:
90
- print(f"Erreur: {str(e)}")
91
- return None
92
-
93
-
94
-
95
- processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
96
- model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
97
- vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
98
 
99
- embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
100
- speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
 
 
 
 
 
 
 
 
 
101
 
 
102
 
103
- def synthesise(text):
104
- input_ids = processor(text=text, return_tensors="pt")["input_ids"]
105
- try:
106
  speech = model.generate_speech(
107
- input_ids.to(device),
108
- speaker_embeddings.to(device),
109
  vocoder=vocoder
110
  )
111
- return speech.cpu()
112
- except Exception as e:
113
- print(f"Erreur lors de la synthèse vocale : {e}")
114
- return None
115
-
116
- # launch_fn(debug=True)
117
- # transcription = transcribe()
118
- # response = query(transcription)
119
- # audio = synthesise(response)
120
- #
121
- # sd.play(audio.numpy(), 16000)
122
- # sd.wait()
123
-
124
- # Interface Gradio
125
- def assistant_vocal_interface():
126
- launch_fn(debug=True)
127
- transcription = transcribe()
128
- response = query(transcription)
129
- audio = synthesise(response)
130
- return transcription, response, (16000, audio.numpy())
131
-
132
- with gr.Blocks(title="Assistant Vocal") as demo:
133
- gr.Markdown("## Assistant vocal : détection, transcription, génération et synthèse")
134
-
135
- start_btn = gr.Button("Démarrer l'assistant")
136
- transcription_box = gr.Textbox(label="Transcription")
137
- response_box = gr.Textbox(label="Réponse IA")
138
- audio_output = gr.Audio(label="Synthèse vocale", type="numpy", autoplay=True)
139
-
140
- start_btn.click(
141
- assistant_vocal_interface,
142
- inputs=[],
 
 
 
 
 
 
 
 
143
  outputs=[transcription_box, response_box, audio_output]
144
  )
145
 
146
- demo.launch(share=True)
 
 
 
1
  import torch
2
+ from transformers import pipeline, SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
3
+ from huggingface_hub import InferenceClient
 
 
4
  from datasets import load_dataset
 
 
 
 
5
  import gradio as gr
6
+ import os
7
+ import numpy as np
8
 
9
# Read the API token (must be defined in the Space's Secrets).
HF_TOKEN = os.getenv("HF_TOKEN")

# Hardware detection (GPU or CPU); all models below are moved to this device.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device utilisé : {device}")

# --- 1. Transcription model (ASR) ---
# Whisper base (English-only) pipeline; accepts file paths directly.
transcriber = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-base.en",
    device=device
)

# --- 2. LLM client (intelligence) ---
# Chat completions are routed through the Fireworks provider on the HF hub.
client = InferenceClient(
    provider="fireworks-ai",
    api_key=HF_TOKEN
)

# --- 3. Speech synthesis (TTS) ---
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
model = SpeechT5ForTextToSpeech.from_pretrained("microsoft/speecht5_tts").to(device)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)

# Speaker embedding (voice identity) taken from the CMU Arctic x-vector set;
# index 7306 selects one fixed speaker.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embeddings = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0).to(device)
38
+
39
def transcribe(audio_path):
    """Turn a recorded audio file into text via the ASR pipeline.

    Args:
        audio_path: Filesystem path handed over by the Gradio microphone
            widget, or None when nothing was recorded.

    Returns:
        The transcribed text, or "" for missing input.
    """
    if audio_path is None:
        return ""
    # Whisper pipelines accept the file path Gradio provides as-is.
    result = transcriber(audio_path)
    return result["text"]
47
+
48
def query_llm(text):
    """Ask the hosted LLM for a short reply to *text*.

    Returns the model's answer, a fixed fallback for empty input, or an
    error string (best-effort: the UI shows it instead of crashing).
    """
    if not text:
        return "Je n'ai rien entendu."

    # System prompt keeps answers short, which suits speech synthesis.
    chat = [
        {"role": "system", "content": "You are a helpful vocal assistant. Keep your answers short and concise suitable for speech synthesis."},
        {"role": "user", "content": text},
    ]

    try:
        completion = client.chat.completions.create(
            model="accounts/fireworks/models/llama-v3p1-8b-instruct",  # Fireworks model id as exposed via the HF client
            messages=chat,
            max_tokens=150,  # cap the reply so synthesis stays short
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Deliberate best-effort: surface the error as the "answer".
        return f"Erreur LLM: {str(e)}"
 
 
 
 
 
 
 
68
 
69
def synthesise(text):
    """Convert text to speech with SpeechT5.

    Args:
        text: The text to vocalize. Falsy input short-circuits.

    Returns:
        A ``(sampling_rate, audio_array)`` tuple suitable for ``gr.Audio``,
        or None when there is nothing to say.
    """
    if not text:
        return None

    inputs = processor(text=text, return_tensors="pt")

    # SpeechT5 cannot handle arbitrarily long inputs. Clamp at the token
    # level (600 tokens), which is exact, instead of guessing a character
    # cutoff: re-tokenized truncated text could still exceed the limit.
    input_ids = inputs["input_ids"][:, :600].to(device)

    # Inference only — no gradients needed.
    with torch.no_grad():
        speech = model.generate_speech(
            input_ids,
            speaker_embeddings,
            vocoder=vocoder
        )

    # SpeechT5's HiFi-GAN vocoder emits 16 kHz mono audio.
    return (16000, speech.cpu().numpy())
92
+
93
def process_pipeline(audio_path):
    """Gradio entry point: speech in -> (transcript, answer, speech out).

    Args:
        audio_path: Path of the recorded clip, or None when empty.

    Returns:
        Tuple of (user transcript, LLM answer, audio payload for gr.Audio).
    """
    if audio_path is None:
        return "Aucun audio détecté", "...", None

    # 1. Speech -> text
    user_text = transcribe(audio_path)
    print(f"User: {user_text}")

    # 2. Text -> answer (LLM)
    ai_response = query_llm(user_text)
    print(f"AI: {ai_response}")

    # 3. Answer -> speech (TTS)
    audio_result = synthesise(ai_response)
    return user_text, ai_response, audio_result
110
+
111
# --- Gradio interface ---
# NOTE: component creation order inside the Blocks context defines the
# on-screen layout, so the statements below are order-sensitive.
with gr.Blocks(title="Assistant Vocal AI") as demo:
    gr.Markdown("## 🎙️ Assistant Vocal Llama & Whisper")
    gr.Markdown("Parlez dans le micro, l'IA va transcrire, réfléchir et vous répondre oralement.")

    with gr.Row():
        with gr.Column():
            # Microphone capture; type="filepath" hands the pipeline a temp file path.
            audio_input = gr.Audio(sources=["microphone"], type="filepath", label="Votre voix")
            submit_btn = gr.Button("Envoyer", variant="primary")

        with gr.Column():
            transcription_box = gr.Textbox(label="Ce que j'ai entendu")
            response_box = gr.Textbox(label="Réponse textuelle")
            # autoplay so the spoken answer starts without an extra click
            audio_output = gr.Audio(label="Réponse vocale", autoplay=True)

    # Wire the button to the full ASR -> LLM -> TTS pipeline.
    submit_btn.click(
        fn=process_pipeline,
        inputs=[audio_input],
        outputs=[transcription_box, response_box, audio_output]
    )

if __name__ == "__main__":
    demo.launch()