# test-voice-chat / voicechat-gradio.py
import io
import os
import time
from dataclasses import dataclass, field
from typing import Any
from multiprocessing import freeze_support
# import groq
import gradio as gr
import numpy as np
import spaces
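# transcribe() and generate_tts() come from helper modules bundled alongside
# this file in the Space repo (Whisper speech-to-text and Kokoro text-to-speech,
# respectively).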
from whisper_support import transcribe
from kokoro_support import generate_tts
from transformers import pipeline
import torch
# Text-generation pipeline, created once at import time and shared across requests
default_sys_prompt = """You are a helpful chatbot. You respond very conversationally, and help the end user as best as you can."""
llama3_model_id = "shuyuej/Llama-3.2-1B-Instruct-GPTQ"
llama3_pipe = pipeline(
    "text-generation",
    model=llama3_model_id,
    device_map="auto",
    torch_dtype=torch.float16,
    max_new_tokens=512,
)
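# Note: loading a GPTQ-quantized checkpoint through transformers assumes a GPTQ
# backend (optimum plus auto-gptq or gptqmodel) is installed, and
# device_map="auto" additionally requires accelerate; both are assumed to be in
# the Space's requirements.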
@spaces.GPU
def llama_QA(message_history, system_prompt: str):
    """
    Ask Llama a question and return its answer.

    inputs:
        - message_history [list]: conversation history
        - system_prompt [str]: system prompt for the model
    outputs:
        - response [str]: llama's response
    """
    # Prepend the system prompt to the conversation history
    input_message_history = [{"role": "system", "content": system_prompt}]
    input_message_history.extend(message_history)

    # Generate a response with the text-generation pipeline
    outputs = llama3_pipe(
        input_message_history,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    # The pipeline returns the whole conversation; the final message is the new assistant turn
    response = outputs[0]["generated_text"][-1]["content"]
    return response
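# Example usage (a sketch; the output text is illustrative, not a recorded run):
#
#   history = [{"role": "user", "content": "What's the capital of France?"}]
#   print(llama_QA(history, default_sys_prompt))
#   # e.g. "The capital of France is Paris. ..."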
@dataclass
class AppState:
    conversation: list = field(default_factory=list)
    stopped: bool = False
    model_outs: Any = None
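# AppState.conversation holds OpenAI-style message dicts, e.g.:
#   [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "Hello!"}]
# which is also the format gr.Chatbot(type="messages") expects.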
def process_audio(audio: tuple, state: AppState):
    # Pass-through while recording; the real work happens on stop_recording
    return audio, state
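# With type="numpy", Gradio delivers audio as a (sample_rate, np.ndarray) tuple;
# both process_audio and response receive it in that shape.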
@spaces.GPU(duration=40, progress=gr.Progress(track_tqdm=True))
def response(state: AppState, audio: tuple, system_prompt):
    if not audio:
        return state, state.conversation, None

    # Transcribe the audio clip
    transcription = transcribe(audio)
    if transcription and transcription.startswith("Error"):
        transcription = "Error in audio transcription."

    # Append the user's message in the proper format
    state.conversation.append({"role": "user", "content": transcription})

    # Generate the assistant's response
    assistant_message = llama_QA(state.conversation, system_prompt)

    # Append the assistant's message in the proper format
    state.conversation.append({"role": "assistant", "content": assistant_message})

    # Generate TTS audio for the reply
    response_audio = generate_tts(assistant_message)

    print(state.conversation)
    return state, state.conversation, response_audio
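# The returned triple maps onto [state, chatbot, output_audio] in create_demo();
# returning state.conversation re-renders the full message history each turn.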
def start_recording_user(state: AppState):
    # Clear the input audio component so the next utterance starts fresh
    return None
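# Client-side bootstrap: loads onnxruntime-web and @ricky0123/vad-web from CDN,
# then uses voice-activity detection to auto-click Gradio's record/stop buttons,
# so the user can just talk instead of pressing buttons.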
js = """
async function main() {
    const script1 = document.createElement("script");
    script1.src = "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.14.0/dist/ort.js";
    document.head.appendChild(script1);

    const script2 = document.createElement("script");
    script2.onload = async () => {
        console.log("vad loaded");
        var record = document.querySelector('.record-button');
        record.textContent = "Just Start Talking!";
        record.style = "width: fit-content; padding-right: 0.5vw;";
        const myvad = await vad.MicVAD.new({
            onSpeechStart: () => {
                var record = document.querySelector('.record-button');
                var player = document.querySelector('#streaming-out audio');
                if (record != null && (player == null || player.paused || player.ended)) {
                    console.log("Starting recording", record);
                    record.click();
                } else {
                    console.log("Audio still playing, not starting recording");
                }
            },
            onSpeechEnd: (audio) => {
                var stop = document.querySelector('.stop-button');
                if (stop != null) {
                    console.log("Stopping recording", stop);
                    stop.click();
                }
            }
        });
        myvad.start();
    };
    script2.src = "https://cdn.jsdelivr.net/npm/@ricky0123/vad-web@0.0.7/dist/bundle.min.js";
    script1.onload = () => {
        console.log("onnx loaded");
        document.head.appendChild(script2);
    };
}
"""
js_reset = """
() => {
    var record = document.querySelector('.record-button');
    record.textContent = "Just Start Talking!";
    record.style = "width: fit-content; padding-right: 0.5vw;";
}
"""
def create_demo():
    """Create and return the Gradio demo interface."""
    with gr.Blocks(js=js) as demo:
        with gr.Row():
            system_prompt = gr.Textbox(
                value=default_sys_prompt,
                interactive=True,
            )
        with gr.Row():
            input_audio = gr.Audio(
                label="Input Audio",
                sources=["microphone"],
                type="numpy",
                streaming=False,
            )
        with gr.Row():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
        with gr.Row():
            output_audio = gr.Audio(
                label="Assistant Audio",
                interactive=False,
                autoplay=True,
                elem_id="streaming-out",
            )
        state = gr.State(value=AppState())

        stream = input_audio.start_recording(
            process_audio,
            [input_audio, state],
            [input_audio, state],
        )
        respond = input_audio.stop_recording(
            response,
            inputs=[state, input_audio, system_prompt],
            outputs=[state, chatbot, output_audio],
        )
        restart = respond.then(
            start_recording_user,
            [state],
            [input_audio],
        ).then(
            lambda state: state,
            state,
            state,
            js=js_reset,
        )

        cancel = gr.Button("New Conversation", variant="stop")
        cancel.click(
            lambda: (AppState(), gr.Audio(recording=False)),
            None,
            [state, input_audio],
            cancels=[respond, restart],
        )
    return demo
demo = create_demo()

if __name__ == "__main__":
    freeze_support()  # no-op outside frozen Windows builds; matches the import above
    demo.launch()