# app.py — Gradio voice chatbot (Hugging Face Space "test" by WWMachine,
# revision e290870). The lines above were Hub page chrome accidentally
# captured with the raw file; kept here as a comment so the module parses.
import gradio as gr
from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import os
import re
import time
from deepgram import DeepgramClient, PrerecordedOptions, SpeakOptions
from pydub import AudioSegment # Added for audio stitching
# --- Configuration ---
# SECURITY: read the Deepgram key from the environment. The fallback below is
# the previously hard-coded key, kept only for backward compatibility — any
# key committed to source control must be considered compromised; rotate it
# and delete the fallback.
DEEPGRAM_API_KEY = os.environ.get(
    "DEEPGRAM_API_KEY", "19d640a011569d78395c814e5f875b15cc84deb8"
)
REPO_ID = "Kezovic/iris-q4gguf-v2"
FILENAME = "llama-3.2-1b-instruct.Q4_K_M.gguf"
CONTEXT_WINDOW = 4096   # llama.cpp context length in tokens
MAX_NEW_TOKENS = 512    # cap on tokens generated per reply
TEMPERATURE = 0.7
# Deepgram Limit: Maximum 2000 characters per TTS request.
TTS_MAX_CHARS = 1900  # Use slightly less than max for safety
# --- Initialize Deepgram & LLM ---
# deepgram stays None when no key is configured; downstream code checks this.
deepgram = DeepgramClient(DEEPGRAM_API_KEY) if DEEPGRAM_API_KEY else None
llm = None  # populated by load_llm(); None signals "model unavailable"
def load_llm():
    """Fetch the GGUF weights from the Hugging Face Hub and build the
    llama.cpp model.

    On any failure the error is logged and the module-level ``llm`` is
    left as None so the rest of the app can degrade gracefully.
    """
    global llm
    try:
        gguf_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
        llm = Llama(
            model_path=gguf_path,
            n_ctx=CONTEXT_WINDOW,
            n_threads=2,      # modest thread count for small shared hosts
            verbose=False,
        )
    except Exception as e:
        print(f"Error loading LLM: {e}")


# Load the model eagerly at import time so the first request is fast.
load_llm()
# --- Helper Functions for Splitting ---
def split_text_for_tts(text, max_chars=None):
    """Split text into TTS-sized chunks along sentence boundaries.

    Every returned chunk is at most ``max_chars`` characters, keeping each
    request under Deepgram's 2000-character TTS limit. BUG FIX: the original
    version passed through any single segment longer than ``max_chars``
    unsplit (e.g. a 1900+-char run with no sentence-ending punctuation),
    which would exceed the API limit; such segments are now hard-sliced.
    The default is resolved at call time so TTS_MAX_CHARS edits take effect
    and the function can be used with an explicit limit.

    Args:
        text: full reply text to be spoken.
        max_chars: maximum characters per chunk; None means TTS_MAX_CHARS.

    Returns:
        List of non-empty, stripped chunk strings in original order.
    """
    if max_chars is None:
        max_chars = TTS_MAX_CHARS
    # Split on strong delimiters (period/question/exclamation + space, or
    # newlines); the capture group keeps the delimiters in the segments.
    segments = re.split(r'([.?!]\s+|\n+)', text)
    chunks = []
    current_chunk = ""
    for segment in segments:
        if len(current_chunk) + len(segment) < max_chars:
            current_chunk += segment
            continue
        if current_chunk:
            chunks.append(current_chunk.strip())
        # A segment longer than the limit has no sentence break to split
        # on: hard-slice it so no chunk can ever exceed max_chars.
        while len(segment) > max_chars:
            chunks.append(segment[:max_chars].strip())
            segment = segment[max_chars:]
        current_chunk = segment
    if current_chunk:
        chunks.append(current_chunk.strip())
    return [chunk for chunk in chunks if chunk]
# --- 1. Speech-to-Text (STT) with File Size Check ---
def transcribe(audio_path):
    """Turn a recorded audio file into text via Deepgram's pre-recorded API.

    Returns the transcript string, or None when there is no audio file,
    no configured Deepgram client, or the request fails.
    """
    if not audio_path or deepgram is None:
        return None
    # Deepgram pre-recorded accepts files up to 2 GB; past ~200 MB we only
    # warn, since asynchronous processing would normally be preferable.
    if os.path.getsize(audio_path) > 200 * 1024 * 1024:
        print("Warning: Audio file is large. Transcription may take a moment.")
    try:
        options = PrerecordedOptions(
            smart_format=True,
            model="nova-2",
            language="en-US",
            # Add diarization=True if you want speaker separation in the transcript
        )
        with open(audio_path, "rb") as audio_file:
            response = deepgram.listen.rest.v("1").transcribe_file(
                {"buffer": audio_file}, options
            )
            return response.results.channels[0].alternatives[0].transcript
    except Exception as e:
        print(f"STT Error: {e}")
        return None
# --- 2. Text-to-Speech (TTS) with Stitching ---
def text_to_speech(text):
    """Convert text to speech, splitting long text and stitching the audio.

    The text is chunked under Deepgram's per-request limit, each chunk is
    synthesized to a temporary WAV, and the pieces are stitched with short
    pauses. Returns the path of the final WAV, or None when no client is
    configured or every chunk failed.

    BUG FIX: temp-file cleanup now runs in ``finally`` — previously an
    exception between the API call and ``os.remove`` (e.g. a failed
    ``AudioSegment.from_wav``) leaked the temporary chunk file.
    """
    if deepgram is None:
        return None
    # Step A: Split text into small chunks under the API limit.
    text_chunks = split_text_for_tts(text)
    audio_segments = []
    # Step B: Generate audio for each chunk; a failed chunk is skipped so a
    # single API hiccup does not lose the whole reply.
    for i, chunk in enumerate(text_chunks):
        temp_filename = f"temp_tts_chunk_{i}_{int(time.time())}.wav"
        try:
            options = SpeakOptions(
                model="aura-2-phoebe-en", encoding="linear16", container="wav"
            )
            deepgram.speak.rest.v("1").save(temp_filename, {"text": chunk}, options)
            # Load the temporary audio into pydub for stitching.
            audio_segments.append(AudioSegment.from_wav(temp_filename))
        except Exception as e:
            print(f"TTS API FAILED for chunk {i}: {e}. Skipping chunk.")
        finally:
            # Always remove the per-chunk temp file, even on failure.
            if os.path.exists(temp_filename):
                os.remove(temp_filename)
    if not audio_segments:
        return None
    # Step C: Stitch the audio files together with a 200 ms pause between
    # sentences for better flow.
    stitched_audio = audio_segments[0]
    for segment in audio_segments[1:]:
        stitched_audio += AudioSegment.silent(duration=200)
        stitched_audio += segment
    # Step D: Export the final stitched file (timestamped name avoids
    # clobbering a previous reply that may still be playing).
    final_filename = f"final_response_{int(time.time())}.wav"
    stitched_audio.export(final_filename, format="wav")
    return final_filename
# --- Main Chat Logic ---
def run_chat_pipeline(audio_input, history, state_messages):
    """Full voice round-trip: STT -> LLM -> TTS.

    Args:
        audio_input: filepath of the recorded audio from the mic widget.
        history: Chatbot display history (list of (user, ai) tuples).
        state_messages: OpenAI-style message dicts, the LLM's memory.

    Returns:
        (history, state_messages, audio_path, audio_input_reset) — exactly
        four values to match the four outputs wired to the submit button.

    BUG FIX: the original early returns yielded only three values while the
    click handler declares four outputs, so Gradio raised "returned too few
    output values" whenever the model was missing or transcription failed.
    """
    if llm is None:
        return history, state_messages, None, None
    # 1. Transcribe Audio (STT)
    user_text = transcribe(audio_input)
    if not user_text:
        # If transcription fails (e.g., bad audio, API key error), inform
        # the user via the chat rather than crashing the handler.
        history.append(("", "System Error: Could not process audio. Check API Key or try speaking louder."))
        return history, state_messages, None, None
    # 2. Update UI and State with User Message
    state_messages.append({"role": "user", "content": user_text})
    history.append((user_text, None))
    # 3. LLM Generation (contextual — the whole message list is replayed)
    try:
        completion = llm.create_chat_completion(
            messages=state_messages,
            max_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE
        )
        ai_text = completion['choices'][0]['message']['content']
    except Exception as e:
        ai_text = f"LLM Generation Error: {str(e)}"
    # 4. Update UI and State with AI Response
    state_messages.append({"role": "assistant", "content": ai_text})
    history[-1] = (user_text, ai_text)
    # 5. Generate Audio (TTS with splitting/stitching)
    audio_path = text_to_speech(ai_text)
    # Final None clears the microphone widget for the next recording.
    return history, state_messages, audio_path, None
# --- Gradio UI Layout ---
with gr.Blocks(title="Voice Chatbot") as demo:
    gr.Markdown("## 🎙️ Voice-First AI Chat (Memory & Long-Text Handled)")

    # Visible transcript plus a hidden per-session copy of the raw
    # message dicts that the LLM consumes.
    chatbot = gr.Chatbot(label="Conversation", height=500)
    state_messages = gr.State([])

    with gr.Row():
        with gr.Column(scale=4):
            audio_input = gr.Audio(
                sources=["microphone"],
                type="filepath",
                label="Record Your Message",
            )
        with gr.Column(scale=1):
            submit_btn = gr.Button("Send Voice 💬", variant="primary")
            clear_btn = gr.Button("Clear Memory 🗑️")

    audio_player = gr.Audio(
        label="AI Voice",
        autoplay=True,
        interactive=False,
    )

    # --- Event Wiring ---
    submit_btn.click(
        fn=run_chat_pipeline,
        inputs=[audio_input, chatbot, state_messages],
        outputs=[chatbot, state_messages, audio_player, audio_input],
    )

    def clear_all():
        """Reset the transcript, the LLM memory, and the audio player."""
        return [], [], None

    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[chatbot, state_messages, audio_player],
    )

if __name__ == "__main__":
    demo.launch()