Spaces:
Sleeping
Sleeping
| import os | |
| import tempfile | |
| import asyncio | |
| from fastapi import FastAPI, WebSocket | |
| from groq import Groq | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
# --- Service configuration ---------------------------------------------------
# SECURITY: API keys must never be committed to source control. Both keys are
# read from the process environment instead; any key that was previously
# embedded in this file should be treated as leaked and revoked.
#
# Required environment variables:
#   GOOGLE_API_KEY - read implicitly by ChatGoogleGenerativeAI
#   GROQ_API_KEY   - passed explicitly to the Groq client below
_missing_keys = [
    name for name in ("GOOGLE_API_KEY", "GROQ_API_KEY") if not os.environ.get(name)
]
if _missing_keys:
    # Fail fast at import time with a clear message rather than on the first
    # request with an opaque authentication error.
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_keys)
    )

# Groq client, used for both speech-to-text and text-to-speech.
client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Gemini chat model, used to generate the reply to each transcript.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")

app = FastAPI(title="Realtime STT → Gemini → TTS")
# Minimum number of buffered audio bytes before a transcription is attempted.
# NOTE(review): how much audio this represents depends on the client's sample
# rate and sample width, which are not visible here - confirm with the sender.
_MIN_AUDIO_BYTES = 16000

# Size of each binary frame used when streaming synthesized audio back.
_TTS_CHUNK_BYTES = 4096


def _transcribe(audio):
    """Send *audio* (raw bytes) to Groq Whisper and return the stripped text.

    Blocking: call via ``asyncio.to_thread`` from async code.
    """
    # The bytes are uploaded directly; the ".wav" filename only labels the
    # multipart upload, so no temporary file is needed.
    transcription = client.audio.transcriptions.create(
        file=("audio.wav", audio),
        model="whisper-large-v3-turbo",
        response_format="verbose_json",
    )
    return transcription.text.strip()


def _synthesize(text):
    """Convert *text* to WAV audio via Groq TTS and return the audio bytes.

    Blocking: call via ``asyncio.to_thread`` from async code. The temporary
    file used to receive the audio is always removed before returning.
    """
    # Close the handle immediately; only the path is needed for write_to_file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        out_path = tmp.name
    try:
        response = client.audio.speech.create(
            model="playai-tts",
            voice="Atlas-PlayAI",
            response_format="wav",
            input=text,
        )
        response.write_to_file(out_path)
        with open(out_path, "rb") as f:
            return f.read()
    finally:
        try:
            os.remove(out_path)
        except OSError:
            pass  # already gone; nothing left to clean up


async def _respond_to_audio(ws, audio):
    """Run one STT -> Gemini -> TTS round for *audio* and stream results to *ws*."""
    # Groq's client is synchronous; run it in a worker thread so the event
    # loop keeps serving other connections.
    text_chunk = await asyncio.to_thread(_transcribe, audio)
    if not text_chunk:
        return
    await ws.send_text(f"PARTIAL_TRANSCRIPT: {text_chunk}")

    # `astream` is the async counterpart of `stream`; the synchronous
    # iterator returned by `stream` cannot be consumed with `async for`.
    tokens = []
    async for event in llm.astream(text_chunk):
        if event.content:
            token = event.content
            tokens.append(token)
            await ws.send_text(f"AI_TOKEN: {token}")
    response_accum = "".join(tokens)

    # Skip synthesis when the model produced no text, but still signal
    # completion so the client is not left waiting.
    if response_accum:
        speech = await asyncio.to_thread(_synthesize, response_accum)
        for start in range(0, len(speech), _TTS_CHUNK_BYTES):
            await ws.send_bytes(speech[start:start + _TTS_CHUNK_BYTES])
            await asyncio.sleep(0.01)  # brief pause between frames
    await ws.send_text("TTS_DONE")


# NOTE(review): no @app.websocket(...) decorator is visible on this handler,
# so it is not registered on `app` in this file - confirm the route is added
# elsewhere or add the decorator with the intended path.
async def websocket_stt_tts(ws: WebSocket):
    """Serve one realtime speech conversation over a WebSocket.

    Protocol, as implemented here:
      * Binary frames are appended to an audio buffer. Once the buffer exceeds
        ``_MIN_AUDIO_BYTES`` it is transcribed, the transcript is sent as
        ``PARTIAL_TRANSCRIPT: ...``, the Gemini reply is streamed as
        ``AI_TOKEN: ...`` messages, the synthesized reply is streamed back as
        binary frames, and ``TTS_DONE`` marks the end of the round.
      * The text frame ``CLOSE`` closes the connection.
      * Any failure is reported to the client as ``ERROR: ...`` when the
        socket is still usable, and the connection is then closed.
    """
    await ws.accept()
    buffer = bytearray()
    try:
        while True:
            message = await ws.receive()

            # The client went away: stop instead of calling receive() again,
            # which would raise on a disconnected socket.
            if message.get("type") == "websocket.disconnect":
                break

            # ASGI servers may include the "bytes" key with a None value on
            # text frames, so test the value rather than key presence.
            data = message.get("bytes")
            if data is not None:
                buffer += data
                if len(buffer) > _MIN_AUDIO_BYTES:
                    audio = bytes(buffer)
                    buffer.clear()
                    await _respond_to_audio(ws, audio)
            elif message.get("text") == "CLOSE":
                await ws.close()
                break
    except Exception as e:
        # Best-effort error report: the socket may already be closed, in which
        # case sending would raise again and mask the original failure.
        try:
            await ws.send_text(f"ERROR: {str(e)}")
            await ws.close()
        except Exception:
            pass