import asyncio
import os
import tempfile

from fastapi import FastAPI, WebSocket
from groq import Groq
from langchain_google_genai import ChatGoogleGenerativeAI

# SECURITY(review): these credentials were committed to source control and must
# be rotated — treat them as compromised. They are kept only as fallbacks so an
# existing deployment keeps working; prefer setting GOOGLE_API_KEY and
# GROQ_API_KEY in the environment.
os.environ.setdefault("GOOGLE_API_KEY", "AIzaSyD2DMFgcL0kWTQYhii8wseSHY3BRGWSebk")
_GROQ_API_KEY = os.environ.get(
    "GROQ_API_KEY", "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"
)

client = Groq(api_key=_GROQ_API_KEY)
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")

app = FastAPI(title="Realtime STT → Gemini → TTS")

# Minimum buffered audio bytes before a transcription pass runs
# (~1 s at 16 kHz per the original comment; assumes 8-bit mono — TODO confirm
# the client's actual sample format).
_MIN_BUFFER_BYTES = 16000


async def _transcribe(audio: bytes) -> str:
    """Run Groq Whisper STT on *audio* and return the stripped transcript.

    The audio is written to a temporary ``.wav`` file because the Groq SDK is
    given a ``(filename, bytes)`` tuple. NOTE(review): raw received chunks are
    written without a WAV header — confirm the client sends complete WAV data
    rather than bare PCM.
    """
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(audio)
            tmp_path = tmp.name
        with open(tmp_path, "rb") as f:
            payload = f.read()
        # Blocking SDK call; run it off the event loop so other websocket
        # connections are not stalled while Whisper processes the chunk.
        transcription = await asyncio.to_thread(
            client.audio.transcriptions.create,
            file=(tmp_path, payload),
            model="whisper-large-v3-turbo",
            response_format="verbose_json",
        )
        return transcription.text.strip()
    finally:
        # BUG FIX: the original leaked one temp file per processed chunk.
        if tmp_path and os.path.exists(tmp_path):
            os.remove(tmp_path)


async def _synthesize_and_send(ws: WebSocket, text: str) -> None:
    """Generate TTS audio for *text*, stream it to *ws* in 4 KiB chunks,
    then signal completion with a ``TTS_DONE`` text frame."""
    out_path = None
    try:
        # Create-and-close immediately: we only need a unique path. The
        # original leaked the open file handle AND never deleted the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            out_path = tmp.name
        response = await asyncio.to_thread(
            client.audio.speech.create,
            model="playai-tts",
            voice="Atlas-PlayAI",
            response_format="wav",
            input=text,
        )
        response.write_to_file(out_path)
        with open(out_path, "rb") as f:
            while chunk := f.read(4096):
                await ws.send_bytes(chunk)
                # Brief pause paces delivery so the client buffer isn't flooded.
                await asyncio.sleep(0.01)
        await ws.send_text("TTS_DONE")
    finally:
        if out_path and os.path.exists(out_path):
            os.remove(out_path)


@app.websocket("/ws/stream")
async def websocket_stt_tts(ws: WebSocket):
    """Realtime pipeline: audio in → Whisper STT → Gemini reply → TTS audio out.

    Protocol, per received message:
      * binary frame — appended to an audio buffer; once the buffer exceeds
        ``_MIN_BUFFER_BYTES`` it is transcribed, the partial transcript is
        echoed (``PARTIAL_TRANSCRIPT: …``), a streamed Gemini reply is
        forwarded token by token (``AI_TOKEN: …``), and synthesized speech for
        the full reply is streamed back as binary frames ending in ``TTS_DONE``.
      * text frame ``"CLOSE"`` — ends the session.
    """
    await ws.accept()
    buffer = b""
    try:
        while True:
            message = await ws.receive()
            if "bytes" in message:
                buffer += message["bytes"]
                if len(buffer) <= _MIN_BUFFER_BYTES:
                    continue
                audio, buffer = buffer, b""
                text_chunk = await _transcribe(audio)
                if not text_chunk:
                    continue
                await ws.send_text(f"PARTIAL_TRANSCRIPT: {text_chunk}")
                # BUG FIX: llm.stream() returns a *sync* generator, which the
                # original fed to `async for` (TypeError at runtime). astream()
                # is the async counterpart and is `async for`-iterable.
                response_accum = ""
                async for event in llm.astream(text_chunk):
                    if event.content:
                        token = event.content
                        response_accum += token
                        await ws.send_text(f"AI_TOKEN: {token}")
                await _synthesize_and_send(ws, response_accum)
            elif "text" in message and message["text"] == "CLOSE":
                await ws.close()
                break
    except Exception as e:
        # Best-effort error report: the socket may already be dead, in which
        # case send/close themselves raise — swallow that secondary failure
        # rather than masking the original exception with a new traceback.
        try:
            await ws.send_text(f"ERROR: {str(e)}")
            await ws.close()
        except Exception:
            pass