File size: 3,187 Bytes
9268011
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import asyncio
import os
import tempfile

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from groq import Groq
from langchain_google_genai import ChatGoogleGenerativeAI

# SECURITY FIX: API keys were hard-coded in source (those keys are burned —
# revoke them). Credentials now come from the environment:
#   GOOGLE_API_KEY — read implicitly by ChatGoogleGenerativeAI
#   GROQ_API_KEY   — passed explicitly to the Groq client
client = Groq(api_key=os.environ["GROQ_API_KEY"])

# Gemini chat model used to answer each transcribed utterance.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash")

app = FastAPI(title="Realtime STT → Gemini → TTS")

@app.websocket("/ws/stream")
async def websocket_stt_tts(ws: WebSocket):
    """Realtime pipeline: client audio → Whisper STT → Gemini → TTS audio.

    Protocol (server → client):
      text  "PARTIAL_TRANSCRIPT: <text>"  — Whisper transcript of ~1 s of audio
      text  "AI_TOKEN: <token>"           — streamed Gemini response token
      bytes <wav chunk>                   — synthesized speech, 4 KiB chunks
      text  "TTS_DONE"                    — end of one TTS audio stream
    Client → server: binary frames of raw audio (assumed ~16 kHz — TODO
    confirm with the client encoder); the text frame "CLOSE" ends the session.
    """
    await ws.accept()
    buffer = b""

    try:
        while True:
            message = await ws.receive()

            # Binary frame: accumulate audio until ~1 s is buffered.
            if message.get("bytes"):
                buffer += message["bytes"]

                if len(buffer) <= 16000:  # ~1 sec @ 16 kHz
                    continue

                tmp_path = None
                out_path = None
                try:
                    with tempfile.NamedTemporaryFile(
                        delete=False, suffix=".wav"
                    ) as tmp:
                        tmp.write(buffer)
                        tmp_path = tmp.name
                    buffer = b""

                    # STT: partial transcription of the buffered chunk.
                    with open(tmp_path, "rb") as f:
                        transcription = client.audio.transcriptions.create(
                            file=(tmp_path, f.read()),
                            model="whisper-large-v3-turbo",
                            response_format="verbose_json",
                        )
                    text_chunk = transcription.text.strip()
                    if not text_chunk:
                        continue

                    await ws.send_text(f"PARTIAL_TRANSCRIPT: {text_chunk}")

                    # BUG FIX: llm.stream() returns a *synchronous* iterator,
                    # which `async for` cannot consume (TypeError at runtime).
                    # Use the async counterpart astream().
                    response_accum = ""
                    async for event in llm.astream(text_chunk):
                        if event.content:
                            token = event.content
                            response_accum += token
                            await ws.send_text(f"AI_TOKEN: {token}")

                    # TTS for the accumulated Gemini response.
                    out_path = tempfile.NamedTemporaryFile(
                        delete=False, suffix=".wav"
                    ).name
                    speech = client.audio.speech.create(
                        model="playai-tts",
                        voice="Atlas-PlayAI",
                        response_format="wav",
                        input=response_accum,
                    )
                    speech.write_to_file(out_path)

                    # Stream synthesized audio back in 4 KiB chunks; the
                    # short sleep paces delivery and yields the event loop.
                    with open(out_path, "rb") as f:
                        while chunk := f.read(4096):
                            await ws.send_bytes(chunk)
                            await asyncio.sleep(0.01)

                    await ws.send_text("TTS_DONE")
                finally:
                    # LEAK FIX: temp files were created with delete=False and
                    # never removed — clean them up on every iteration.
                    for path in (tmp_path, out_path):
                        if path and os.path.exists(path):
                            os.remove(path)

            # Text frame: "CLOSE" ends the session cleanly.
            elif message.get("text") == "CLOSE":
                await ws.close()
                break

    except WebSocketDisconnect:
        # Client went away — nothing to send, nothing to close.
        pass
    except Exception as e:
        # Best-effort error report; the socket may already be dead, so the
        # report itself must never raise out of the handler.
        try:
            await ws.send_text(f"ERROR: {e}")
            await ws.close()
        except Exception:
            pass