# websokectstts / app.py
# Author: 1MR
# Commit 9268011 (verified): "Create app.py"
import os
import tempfile
import asyncio
from fastapi import FastAPI, WebSocket
from groq import Groq
from langchain_google_genai import ChatGoogleGenerativeAI
# Credentials come from the environment instead of being committed to source.
# NOTE(review): the Google and Groq keys previously hard-coded on these lines are
# exposed in version history and should be revoked/rotated.
_GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not _GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable is not set")
if not os.environ.get("GOOGLE_API_KEY"):
    raise RuntimeError("GOOGLE_API_KEY environment variable is not set")

# Groq client used for both speech-to-text and text-to-speech.
client = Groq(api_key=_GROQ_API_KEY)
# No key is passed explicitly: ChatGoogleGenerativeAI reads GOOGLE_API_KEY from the environment.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")
app = FastAPI(title="Realtime STT → Gemini → TTS")
# Transcribe the accumulated client audio once the buffer grows past this many bytes.
# NOTE(review): an earlier comment called this "~1 sec @16kHz". That is only true
# for 8-bit mono samples; for 16-bit mono PCM at 16 kHz it is ~0.5 s. Confirm
# against the client's actual encoding.
_STT_FLUSH_BYTES = 16000
# Size of each binary frame used to stream synthesized audio back to the client.
_TTS_FRAME_BYTES = 4096


def _transcribe(audio: bytes) -> str:
    """Transcribe *audio* with Groq Whisper and return the stripped text.

    Blocking network call: run it in a worker thread, not on the event loop.
    """
    # NOTE(review): the bytes are labelled as WAV exactly as received from the
    # client. If the client streams raw PCM, or a WAV stream whose header only
    # appears in its first frame, the flushed chunks will lack a RIFF header.
    # Confirm the client's wire format.
    transcription = client.audio.transcriptions.create(
        file=("audio.wav", audio),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    )
    return transcription.text.strip()


def _synthesize_speech(text: str) -> bytes:
    """Synthesize *text* to WAV with Groq TTS and return the audio bytes.

    Blocking network call: run it in a worker thread, not on the event loop.
    The intermediate temp file is always removed, even if synthesis fails.
    """
    fd, out_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # write_to_file opens the path itself; don't leak this descriptor
    try:
        response = client.audio.speech.create(
            model="playai-tts",
            voice="Atlas-PlayAI",
            response_format="wav",
            input=text,
        )
        response.write_to_file(out_path)
        with open(out_path, "rb") as f:
            return f.read()
    finally:
        os.remove(out_path)


@app.websocket("/ws/stream")
async def websocket_stt_tts(ws: WebSocket):
    """Run a realtime STT -> Gemini -> TTS loop over one WebSocket connection.

    Protocol as implemented here:
      * Binary frames from the client are appended to an audio buffer. Once the
        buffer exceeds ``_STT_FLUSH_BYTES``, it is transcribed. A non-empty
        transcript is sent back as ``"PARTIAL_TRANSCRIPT: <text>"``.
      * The transcript is sent to Gemini, and each streamed token is forwarded
        as ``"AI_TOKEN: <token>"``.
      * The full reply (if any) is synthesized to WAV and streamed back in
        binary frames, followed by ``"TTS_DONE"``.
      * A text frame ``"CLOSE"`` closes the connection. Other text frames are
        ignored.
      * A client disconnect ends the loop quietly.
      * Any other error is reported to the client as ``"ERROR: <message>"``,
        and the socket is then closed.
    """
    await ws.accept()
    loop = asyncio.get_running_loop()
    buffer = bytearray()
    try:
        while True:
            message = await ws.receive()

            # The client is gone; the socket can no longer be read from or closed.
            if message["type"] == "websocket.disconnect":
                break

            # ASGI allows both "bytes" and "text" keys with one of them None,
            # so test the value rather than key presence.
            data = message.get("bytes")
            if data is not None:
                buffer.extend(data)
                if len(buffer) <= _STT_FLUSH_BYTES:
                    continue
                audio = bytes(buffer)
                buffer.clear()

                # Groq calls are synchronous; keep them off the event loop so
                # other connections continue to be served.
                text_chunk = await loop.run_in_executor(None, _transcribe, audio)
                if not text_chunk:
                    continue
                await ws.send_text(f"PARTIAL_TRANSCRIPT: {text_chunk}")

                # llm.stream() is a synchronous iterator and cannot be consumed
                # with `async for`; astream() is its async counterpart.
                tokens = []
                async for event in llm.astream(text_chunk):
                    if event.content:
                        tokens.append(event.content)
                        await ws.send_text(f"AI_TOKEN: {event.content}")
                reply = "".join(tokens)

                # Skip TTS when there is nothing to say: empty input would only error.
                if reply:
                    speech = await loop.run_in_executor(None, _synthesize_speech, reply)
                    for start in range(0, len(speech), _TTS_FRAME_BYTES):
                        await ws.send_bytes(speech[start:start + _TTS_FRAME_BYTES])
                        await asyncio.sleep(0.01)  # short pause between frames
                await ws.send_text("TTS_DONE")

            elif message.get("text") == "CLOSE":
                await ws.close()
                break
    except Exception as e:
        try:
            await ws.send_text(f"ERROR: {e}")
            await ws.close()
        except Exception:
            # Best effort only: the socket is already broken (e.g. the client
            # dropped mid-send), so there is no one left to report the error to.
            pass