# Audio_test / app.py
# Last change: commit 213fb68 by afouda ("Update app.py", verified).
import asyncio
import base64
import os
import time
from io import BytesIO
from google.genai import types
from google.genai.types import (
LiveConnectConfig,
SpeechConfig,
VoiceConfig,
PrebuiltVoiceConfig,
Content,
Part,
)
import gradio as gr
import numpy as np
import websockets
from dotenv import load_dotenv
from fastrtc import (
AsyncAudioVideoStreamHandler,
Stream,
WebRTC,
get_cloudflare_turn_credentials_async,
wait_for_item,
)
from google import genai
from gradio.utils import get_space
from PIL import Image
# ------------------------------------------
import asyncio
import base64
import json
import os
import pathlib
from typing import AsyncGenerator, Literal
import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastrtc import (
AsyncStreamHandler,
Stream,
get_cloudflare_turn_credentials_async,
wait_for_item,
)
from google import genai
from google.genai.types import (
LiveConnectConfig,
PrebuiltVoiceConfig,
SpeechConfig,
VoiceConfig,
)
from gradio.utils import get_space
from pydantic import BaseModel
# ------------------------------------------------
from dotenv import load_dotenv
load_dotenv()
import os
import io
import asyncio
from pydub import AudioSegment
async def safe_get_ice_config_async():
    """Resolve the WebRTC ICE server configuration for the audio component.

    Cloudflare TURN credentials (via ``get_cloudflare_turn_credentials_async``)
    are requested only when ``HF_TOKEN`` is set, or when both
    ``CLOUDFLARE_TURN_KEY_ID`` and ``CLOUDFLARE_TURN_KEY_API_TOKEN`` are set.
    If neither is configured, or the request raises, a STUN-only configuration
    is returned instead. Local runs without those variables therefore still get
    a usable config rather than an error.
    """
    stun_only = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}

    has_hf_token = bool(os.getenv("HF_TOKEN"))
    has_cloudflare_pair = bool(os.getenv("CLOUDFLARE_TURN_KEY_ID")) and bool(
        os.getenv("CLOUDFLARE_TURN_KEY_API_TOKEN")
    )
    if not (has_hf_token or has_cloudflare_pair):
        # Nothing to authenticate with: go straight to the public STUN server.
        return stun_only

    try:
        return await get_cloudflare_turn_credentials_async()
    except Exception as err:
        print("Warning: failed to get Cloudflare TURN credentials, falling back to STUN-only. Error:", err)
        return stun_only
# Gemini: google-genai
from google import genai
# ---------------------------------------------------
# VAD imports from reference code
import collections
import webrtcvad
import time
# Weaviate imports
import weaviate
from weaviate.classes.init import Auth
from contextlib import contextmanager
# helper functions
# ---------------------------------------------------------------------------
# Service configuration.
# Credentials come from the environment; load_dotenv() at the top of this file
# also loads them from a local .env file. They must never be hard-coded in
# source. Any key that was previously committed here is compromised and must
# be rotated.
# ---------------------------------------------------------------------------
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
# The cluster host is not a secret, so it keeps a default; the API key has none.
WEAVIATE_URL = os.getenv("WEAVIATE_URL", "18vysvlxqza0ux821ecbg.c0.us-west3.gcp.weaviate.cloud")
WEAVIATE_API_KEY = os.getenv("WEAVIATE_API_KEY")
DEEPINFRA_API_KEY = os.getenv("DEEPINFRA_API_KEY")
DEEPINFRA_BASE_URL = os.getenv("DEEPINFRA_BASE_URL", "https://api.deepinfra.com/v1/openai")

# Warn early rather than failing later inside a request with an opaque auth error.
_missing_secrets = [
    name
    for name, value in (
        ("GEMINI_API_KEY", GEMINI_API_KEY),
        ("WEAVIATE_API_KEY", WEAVIATE_API_KEY),
        ("DEEPINFRA_API_KEY", DEEPINFRA_API_KEY),
    )
    if not value
]
if _missing_secrets:
    print(
        "Warning: missing environment variable(s):",
        ", ".join(_missing_secrets),
        "- the features that depend on them will fail until they are set.",
    )

from openai import OpenAI

# OpenAI-compatible client pointed at DeepInfra; used for embeddings
# (GeminiHandler.embed_texts calls openai.embeddings.create).
openai = OpenAI(
    # OpenAI() refuses to construct when api_key is None and OPENAI_API_KEY is
    # unset. An empty string defers the failure to request time, where
    # embed_texts catches and logs it, so a missing key does not crash import.
    api_key=DEEPINFRA_API_KEY or "",
    base_url=DEEPINFRA_BASE_URL,
)
@contextmanager
def weaviate_client():
    """Open a Weaviate Cloud connection for the duration of a ``with`` block.

    Connects with ``weaviate.connect_to_weaviate_cloud``, using the module-level
    ``WEAVIATE_URL`` and ``WEAVIATE_API_KEY``, and yields the client. ``close()``
    runs in ``finally``, so the connection is released even when the ``with``
    body raises.
    """
    credentials = Auth.api_key(WEAVIATE_API_KEY)
    conn = weaviate.connect_to_weaviate_cloud(
        cluster_url=WEAVIATE_URL,
        auth_credentials=credentials,
    )
    try:
        yield conn
    finally:
        conn.close()
def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw audio samples as a base64 ``audio/pcm`` payload.

    The array's raw bytes (``ndarray.tobytes``) are base64-encoded and decoded
    to a UTF-8 ``str``, so the returned dict is JSON-serialisable.
    """
    payload = base64.b64encode(data.tobytes()).decode("UTF-8")
    return dict(mime_type="audio/pcm", data=payload)
def encode_audio2(data: np.ndarray) -> bytes:
    """Return the array's raw sample bytes (``ndarray.tobytes``) with no wrapping."""
    raw_bytes = data.tobytes()
    return raw_bytes
import soundfile as sf
def numpy_array_to_wav_bytes(audio_array, sample_rate=16000):
    """Convert a NumPy audio array to in-memory WAV bytes.

    The file previously defined this function twice. The first copy was dead
    code because the second immediately shadowed it, so only one definition is
    kept.

    Args:
        audio_array (np.ndarray): Audio signal (1D or 2D), passed straight to
            ``soundfile.write``.
        sample_rate (int): Sample rate in Hz written into the WAV header.

    Returns:
        bytes: WAV-formatted audio data.
    """
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format="WAV")
    # getvalue() returns the whole buffer regardless of the current position.
    return buffer.getvalue()
# webrtc handler class
class GeminiHandler(AsyncStreamHandler):
    """Chained speech-to-speech handler for the Gemini API.

    Pipeline, as implemented below:
      1. ``receive`` buffers input audio and segments it into utterances with
         webrtcvad. Each finished utterance is transcribed with ``s2t`` and the
         transcript is queued on ``input_queue``.
      2. ``start_up`` opens a Gemini Live session (audio output) and consumes
         transcripts from ``stream``. For each one it retrieves context with
         ``rag_autism``, generates a reply with ``t2t``, and asks the live
         session to read the reply aloud. The audio chunks go to
         ``output_queue``.
      3. ``emit`` hands queued audio back to the caller.

    The blocking SDK calls (STT, embeddings/Weaviate, chat) run via
    ``asyncio.to_thread`` so they do not stall the event loop that also drives
    audio input and output. ``end_of_speech_time`` is recorded when an
    utterance ends, for latency measurement.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict | None = None,
    ) -> None:
        """Configure sample rates, model names and voice-activity-detection state.

        Args:
            expected_layout: Channel layout passed to the base handler (mono only).
            output_sample_rate: Sample rate, in Hz, attached to emitted audio.
            prompt_dict: Prompt settings carried over by ``copy``. When omitted,
                each instance gets its own ``{"prompt": "PHQ-9"}``. A dict
                literal as the default would be one object shared by every
                instance.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()  # transcripts awaiting a spoken reply
        self.output_queue: asyncio.Queue = asyncio.Queue()  # (sample_rate, int16 array) chunks for emit()
        self.quit: asyncio.Event = asyncio.Event()
        self.prompt_dict = {"prompt": "PHQ-9"} if prompt_dict is None else prompt_dict
        # Live (audio-output) model, used as the text-to-speech stage.
        self.model = "gemini-2.0-flash-live-001"
        self.t2t_model = "gemini-2.5-flash-lite"
        self.s2t_model = "gemini-2.5-flash-lite"

        # --- VAD initialization ---
        # Aggressiveness 3 is webrtcvad's most aggressive non-speech filtering mode.
        self.vad = webrtcvad.Vad(3)
        self.VAD_RATE = 16000
        self.VAD_FRAME_MS = 20
        self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
        self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2  # 2 bytes per int16 sample
        padding_ms = 300
        self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
        # Sliding window of recent (frame_bytes, is_speech) pairs used to detect start/end of speech.
        self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
        # Fraction of the window that must be voiced (to start) or unvoiced (to stop) recording.
        self.vad_ratio = 0.9
        self.vad_triggered = False
        self.wav_data = bytearray()  # PCM bytes of the utterance currently being recorded
        self.internal_buffer = bytearray()  # leftover bytes that do not yet fill a whole VAD frame
        self.end_of_speech_time: float | None = None
        self.first_latency_calculated: bool = False

    def copy(self) -> "GeminiHandler":
        """Return a fresh handler with the same output rate and prompt settings."""
        return GeminiHandler(
            expected_layout="mono",
            output_sample_rate=self.output_sample_rate,
            prompt_dict=self.prompt_dict,
        )

    def s2t(self, audio) -> str:
        """Transcribe WAV bytes with the Gemini speech-to-text model (blocking call).

        Requires ``self.s2t_client``, which ``start_up`` creates. Returns the
        model's ``response.text``.
        """
        response = self.s2t_client.models.generate_content(
            model=self.s2t_model,
            contents=[
                types.Part.from_bytes(data=audio, mime_type="audio/wav"),
                "Generate a transcript of the speech.",
            ],
        )
        return response.text

    def embed_texts(self, texts: list[str], batch_size: int = 50) -> list[list[float]]:
        """Embed ``texts`` in batches with the module-level DeepInfra ``openai`` client.

        A batch whose request raises is logged and contributes one empty list
        per item. This keeps later results aligned with their input positions.
        """
        all_embeddings: list[list[float]] = []
        for start in range(0, len(texts), batch_size):
            batch = texts[start : start + batch_size]
            try:
                resp = openai.embeddings.create(
                    model="Qwen/Qwen3-Embedding-8B",
                    input=batch,
                    encoding_format="float",
                )
                batch_embs = [item.embedding for item in resp.data]
                all_embeddings.extend(batch_embs)
            except Exception as e:
                # The range previously printed with no separator (e.g. "items 049").
                print(f"Embedding batch error (items {start}-{start + len(batch) - 1}): {e}")
                all_embeddings.extend([] for _ in batch)
        return all_embeddings

    def s2t_and_embed(self, audio) -> list[float]:
        """Transcribe ``audio`` with ``s2t``, then embed the transcript.

        Returns ``[]`` when the transcript is empty or embedding produced nothing.
        """
        transcript = self.s2t(audio)  # Step 1: speech -> text
        if not transcript:
            return []
        embeddings = self.embed_texts([transcript])  # Step 2: text -> embedding
        return embeddings[0] if embeddings else []

    def encode_query(self, query: str) -> list[float] | None:
        """Return one embedding vector for ``query``, or ``None`` if embedding failed."""
        embs = self.embed_texts([query], batch_size=1)
        if embs and embs[0]:
            print("Query embedding (first 5 dims):", embs[0][:5])
            return embs[0]
        print("Failed to generate query embedding.")
        return None

    def rag_autism(self, query: str, top_k: int = 3) -> dict:
        """Retrieve context for ``query`` from the Weaviate ``UserDocument`` collection.

        Embeds the query, then runs a ``near_vector`` search with the Weaviate v4
        client. Returns ``{"answer": [text, ...]}`` holding at most ``top_k``
        unique, non-empty ``text`` properties in result order. Any failure
        (embedding or Weaviate) yields ``{"answer": []}``.
        """
        qe = self.encode_query(query)
        if not qe:
            return {"answer": []}
        try:
            with weaviate_client() as client:
                collection = client.collections.get("UserDocument")
                response = collection.query.near_vector(
                    near_vector=qe,
                    limit=top_k,
                    return_properties=["text"],
                )
                # Skip missing/None/empty text so callers can "\n\n".join() the snippets safely.
                hits = [text for obj in response.objects if (text := obj.properties.get("text"))]
                # dict.fromkeys removes repeated chunks while preserving result order.
                return {"answer": list(dict.fromkeys(hits))}
        except Exception as e:
            print("RAG Error:", e)
            return {"answer": []}

    def t2t(self, text: str) -> str:
        """Send ``text`` to the chat session created in ``start_up`` and return the reply.

        Returns a fixed apology when the chat session does not exist yet, and
        ``""`` when the request raises.
        """
        try:
            if not hasattr(self, "chat"):
                print("Error: Chat session (self.chat) is not initialized.")
                return "I'm sorry, my chat function is not ready."
            print("--> Attempting to send prompt to t2t model...")
            response = self.chat.send_message(text)
            print("--> Successfully received response from t2t model.")
            return response.text
        except Exception as e:
            print(f"t2t error: {e}")
            return ""

    async def start_up(self):
        """Create the Gemini clients and chat session, then answer transcripts until shutdown.

        Opens a single Gemini Live session (audio output, voice "Puck") that is
        used only as a text-to-speech engine. The session stays open while
        ``stream`` keeps yielding; it closes once ``shutdown`` sets ``quit``.
        """
        # Flag for whether text-to-text runs in the middle of the chain (read by receive()).
        self.t2t_bool = False
        self.sys_prompt = None
        self.t2t_client = genai.Client(api_key=GEMINI_API_KEY)
        self.s2t_client = genai.Client(api_key=GEMINI_API_KEY)
        chat_instruction = self.sys_prompt if self.sys_prompt is not None else "You are a helpful assistant."
        chat_config = types.GenerateContentConfig(system_instruction=chat_instruction)
        self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)
        self.t2s_client = genai.Client(api_key=GEMINI_API_KEY)

        config = self._live_config(voice_name="Puck")
        async with self.t2s_client.aio.live.connect(model=self.model, config=config) as session:
            async for text_from_user in self.stream():
                print("--------------------------------------------")
                print(f"Received text from user and reading aloud: {text_from_user}")
                print("--------------------------------------------")
                if not text_from_user or not text_from_user.strip():
                    continue
                # 1) Retrieve contextual snippets for the user input.
                context_snippets = await self._retrieve_context(text_from_user)
                # 2) Build the chat prompt, including retrieved context.
                prompt = self._build_prompt(text_from_user, context_snippets)
                print(prompt)
                # 3) Generate the reply text.
                reply_text = await self._generate_reply(prompt)
                if not reply_text:
                    print("No t2t reply generated, skipping t2s send.")
                    continue
                # 4) Have the live session speak the reply.
                await self._speak(session, reply_text)

    def _live_config(self, voice_name: str) -> LiveConnectConfig:
        """Build the Gemini Live config: audio responses spoken with a prebuilt voice.

        A system instruction is attached only when one applies. When
        ``t2t_bool`` is set it is the Wisal persona; otherwise it is
        ``self.sys_prompt``, if that is not None.
        """
        if self.t2t_bool:
            sys_instruction = (
                " You are Wisal, an AI assistant developed by Compumacy AI , and a knowledgeable Autism .\n"
                "Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism Spectrum Disorder (ASD).\n"
                "Always be clear, non-judgmental, and supportive."
            )
        else:
            sys_instruction = self.sys_prompt
        speech_config = SpeechConfig(
            voice_config=VoiceConfig(
                prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
            )
        )
        if sys_instruction is None:
            return LiveConnectConfig(response_modalities=["AUDIO"], speech_config=speech_config)
        return LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=speech_config,
            system_instruction=Content(parts=[Part.from_text(text=sys_instruction)]),
        )

    async def _retrieve_context(self, text: str) -> list[str]:
        """Run ``rag_autism`` off the event loop and return its snippets ([] on error)."""
        try:
            rag_res = await asyncio.to_thread(self.rag_autism, text, top_k=3)
            context_snippets = rag_res.get("answer", []) if isinstance(rag_res, dict) else []
        except Exception as e:
            print("Error running RAG:", e)
            return []
        if context_snippets:
            print("\n--- RAG CONTEXT RETRIEVED ---")
            for i, snippet in enumerate(context_snippets):
                print(f"Snippet {i+1}: {snippet}...")
            print("-----------------------------\n")
        return context_snippets

    @staticmethod
    def _build_prompt(text_from_user: str, context_snippets: list[str]) -> str:
        """Build the chat prompt, grounding it in the retrieved context when any exists."""
        combined_context = "\n\n".join(context_snippets) if context_snippets else ""
        if combined_context:
            return (
                "Please answer the user's question based on the following context. "
                "Be helpful and concise.\n\n"
                f"--- CONTEXT ---\n{combined_context}\n\n"
                f"--- USER QUESTION ---\n{text_from_user}"
            )
        return (
            "Answer the user's question from your own knowledge as a helpful assistant "
            "specializing in Autism Spectrum Disorder.\n\n"
            f"--- USER QUESTION ---\n{text_from_user}"
        )

    async def _generate_reply(self, prompt: str) -> str:
        """Run ``t2t`` off the event loop; return its reply, or "" on error."""
        try:
            reply_text = await asyncio.to_thread(self.t2t, prompt)
            print("\n--- FINAL AI RESPONSE ---")
            print(reply_text)
            print("-----------------------------")
        except Exception as e:
            print("t2t generation error:", e)
            reply_text = ""
        return reply_text

    async def _speak(self, session, reply_text: str) -> None:
        """Send ``reply_text`` to the live session and queue the returned audio for ``emit``.

        Audio chunks are decoded as int16 samples and tagged with
        ``output_sample_rate``. Errors are logged so one failed reply does not
        end the session loop.
        """
        try:
            text_to_speak = f"Read the following text aloud exactly as it is, without adding or changing anything: '{reply_text}'"
            print(f">>> MODIFIED TEXT SENT TO T2S API: '{text_to_speak}' <<<")
            await session.send_client_content(
                turns=types.Content(role="user", parts=[types.Part(text=text_to_speak)])
            )
            async for resp_chunk in session.receive():
                if getattr(resp_chunk, "data", None):
                    array = np.frombuffer(resp_chunk.data, dtype=np.int16)
                    self.output_queue.put_nowait((self.output_sample_rate, array))
        except Exception as e:
            print("Error sending to live TTS session:", e)

    async def stream(self) -> AsyncGenerator[str, None]:
        """Yield transcripts queued by ``receive`` until ``shutdown`` sets ``quit``.

        The queue is polled with a short timeout so the loop re-checks ``quit``
        regularly. A bare ``await queue.get()`` would wait forever after
        shutdown and keep the live session in ``start_up`` open.
        """
        while not self.quit.is_set():
            try:
                text_to_speak = await asyncio.wait_for(self.input_queue.get(), timeout=0.1)
            except (asyncio.TimeoutError, TimeoutError):
                continue
            yield text_to_speak

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Feed one input audio frame through VAD and dispatch completed utterances.

        Bytes are consumed in fixed ``VAD_FRAME_MS`` frames; any remainder stays
        in ``internal_buffer`` for the next call. Recording starts when more
        than ``vad_ratio`` of the sliding window is voiced, and stops when more
        than ``vad_ratio`` of it is unvoiced. At that point the utterance is
        transcribed and queued.
        """
        sr, array = frame
        # NOTE(review): assumes int16 samples at VAD_RATE (input_sample_rate=16000), as webrtcvad requires — confirm upstream.
        self.internal_buffer.extend(array.tobytes())
        while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
            vad_frame = self.internal_buffer[: self.VAD_FRAME_BYTES]
            # In-place deletion avoids re-allocating the remainder on every frame.
            del self.internal_buffer[: self.VAD_FRAME_BYTES]
            is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)
            self.vad_ring_buffer.append((vad_frame, is_speech))
            threshold = self.vad_ratio * self.vad_ring_buffer.maxlen
            if not self.vad_triggered:
                num_voiced = sum(1 for _, speech in self.vad_ring_buffer if speech)
                if num_voiced > threshold:
                    print("Speech detected, starting to record...")
                    self.vad_triggered = True
                    # Keep the buffered padding so the start of the utterance is not clipped.
                    for buffered_frame, _ in self.vad_ring_buffer:
                        self.wav_data.extend(buffered_frame)
                    self.vad_ring_buffer.clear()
            else:
                self.wav_data.extend(vad_frame)
                num_unvoiced = sum(1 for _, speech in self.vad_ring_buffer if not speech)
                if num_unvoiced > threshold:
                    print("End of speech detected.")
                    self.end_of_speech_time = time.monotonic()
                    self.vad_triggered = False
                    pcm_bytes = bytes(self.wav_data)
                    # Reset before transcribing, so a failing STT call cannot leak
                    # this utterance's audio into the next one.
                    self.vad_ring_buffer.clear()
                    self.wav_data = bytearray()
                    await self._transcribe_and_enqueue(pcm_bytes, sr)

    async def _transcribe_and_enqueue(self, pcm_bytes: bytes, sr: int) -> None:
        """Transcribe one utterance off the event loop and queue the text for ``start_up``."""
        try:
            full_utterance_np = np.frombuffer(pcm_bytes, dtype=np.int16)
            audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)
            text_input = await asyncio.to_thread(self.s2t, audio_input_wav)
        except Exception as e:
            print("s2t error:", e)
            return
        print("\n--- FULL S2T TRANSCRIPT ---")
        print(f"'{text_input}'")
        print("---------------------------\n")
        if not (text_input and text_input.strip()):
            print("STT returned empty transcript, skipping.")
            return
        if self.t2t_bool:
            text_message = await asyncio.to_thread(self.t2t, text_input)
        else:
            text_message = text_input
        self.input_queue.put_nowait(text_message)

    async def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next queued audio chunk, or whatever ``wait_for_item`` yields when none is ready."""
        return await wait_for_item(self.output_queue)

    def shutdown(self) -> None:
        """Set ``quit``: ``stream`` then ends, which closes the live session in ``start_up``."""
        self.quit.set()
# Gradio UI: a single send/receive WebRTC audio component served by GeminiHandler.
with gr.Blocks() as demo:
    gr.Markdown("# Gemini Chained Speech-to-Speech Demo")
    with gr.Row() as row2, gr.Column():
        webrtc2 = WebRTC(
            label="Audio Chat",
            modality="audio",
            mode="send-receive",
            elem_id="audio-source",
            rtc_configuration=safe_get_ice_config_async,
            icon="https://www.gstatic.com/lamda/images/gemini_favicon_f069958c85030456e93de685481c559f160ea06b.png",
            pulse_color="rgb(255, 255, 255)",
            icon_button_color="rgb(255, 255, 255)",
        )
        # Session time and concurrency caps apply only when get_space() is truthy.
        _on_space = get_space()
        webrtc2.stream(
            GeminiHandler(),
            inputs=[webrtc2],
            outputs=[webrtc2],
            time_limit=180 if _on_space else None,
            concurrency_limit=2 if _on_space else None,
        )

if __name__ == "__main__":
    _port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=_port, debug=True)