File size: 6,680 Bytes
fc77df5
 
0a9dfed
 
fc77df5
0a9dfed
fc77df5
9c5a6e7
3b000e9
0a9dfed
dccec3c
0a9dfed
fc77df5
 
 
 
 
 
 
3e9ae46
3b000e9
9c5a6e7
dccec3c
fc77df5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a9dfed
fc77df5
 
 
 
9c5a6e7
fc77df5
 
 
 
 
 
 
0a9dfed
fc77df5
0a9dfed
fc77df5
 
 
 
9c5a6e7
 
fc77df5
9c5a6e7
 
fc77df5
9c5a6e7
3e9ae46
fc77df5
 
dccec3c
0a9dfed
fc77df5
 
 
176ee90
3b000e9
 
fc77df5
 
 
 
 
 
 
 
 
 
 
 
0a9dfed
 
 
fc77df5
0a9dfed
 
dccec3c
0a9dfed
fc77df5
 
0a9dfed
 
 
 
fc77df5
 
 
0a9dfed
fc77df5
dccec3c
fc77df5
 
 
 
 
 
 
dccec3c
fc77df5
dccec3c
fc77df5
 
 
 
 
 
 
 
 
0a9dfed
 
fc77df5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a9dfed
fc77df5
 
 
 
 
 
 
 
0a9dfed
fc77df5
 
 
 
0a9dfed
 
fc77df5
 
 
 
 
 
 
 
 
 
0a9dfed
 
fc77df5
0a9dfed
 
fc77df5
 
0a9dfed
fc77df5
0a9dfed
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from __future__ import annotations

import asyncio
import json
import os
import uuid
from typing import Any, Dict, Optional

import numpy as np
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse

from fastrtc import Stream, ReplyOnPause, get_stt_model, get_tts_model

from .gemini_text import (
    gemini_chat_turn,
    get_session,
    deliver_function_result,
)

app = FastAPI()


# ----------------------------
# FastRTC Voice Chat (VAD + STT + TTS)
# ----------------------------

# Model identifiers, overridable through the environment.
STT_MODEL_NAME = os.environ.get("FASTRTC_STT_MODEL", "moonshine/tiny")
TTS_MODEL_NAME = os.environ.get("FASTRTC_TTS_MODEL", "kokoro")

# Loaded once at import time and shared by every voice request: the models are
# CPU-friendly but still heavy on Spaces, so they must not be reloaded per call.
stt = get_stt_model(model=STT_MODEL_NAME)
tts = get_tts_model(model=TTS_MODEL_NAME)


def _run_coro_blocking(coro):
    """
    Run a coroutine to completion from synchronous code and return its result.

    ``asyncio.get_event_loop().run_until_complete(...)`` is fragile here. In a
    worker thread with no event loop, ``get_event_loop`` raises RuntimeError.
    In a thread whose loop is already running, ``run_until_complete`` raises
    RuntimeError as well.

    ``asyncio.run`` covers the no-loop case. When a loop is already running in
    the calling thread, the coroutine is run on a fresh loop in a short-lived
    helper thread. The caller still blocks until the coroutine finishes.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: the expected case for a sync handler.
        return asyncio.run(coro)

    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def _voice_reply_fn(audio: tuple[int, np.ndarray]):
    """
    Handle one paused user utterance: transcribe it, ask Gemini, speak the reply.

    Called by FastRTC's ``ReplyOnPause`` when the user pauses (VAD). This is a
    synchronous generator. NOTE(review): FastRTC may iterate it from a worker
    thread without an event loop, which is why the async chat turn is driven
    via ``_run_coro_blocking`` rather than the thread's "current" loop.

    Args:
        audio: ``(sample_rate, samples)`` tuple as handed over by FastRTC,
            passed unchanged to ``stt.stt``.

    Yields:
        Audio chunks from ``tts.stream_tts_sync`` for the assistant's reply.
        Nothing is yielded if the transcript or the reply text is empty.
    """
    # audio is (sample_rate, int16 mono ndarray)
    # FastRTC STT expects "audio" in the same tuple form per docs examples.
    user_text = (stt.stt(audio) or "").strip()
    if not user_text:
        return

    # For voice sessions we use a synthetic session_id (not a Scratch ws session)
    # because FastRTC's ReplyOnPause fn signature doesn't expose the RTC session id.
    # This keeps a stable conversation state per-process, but not per-user.
    #
    # If per-user memory is needed for voice, switch to a stateful StreamHandler.
    voice_session_id = "voice-global"

    async def noop_emit(_evt: dict):
        # No tool bounce for voice by default: events emitted by the turn are dropped.
        return

    text = _run_coro_blocking(
        gemini_chat_turn(
            session_id=voice_session_id,
            user_text=user_text,
            emit_event=noop_emit,
            model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
        )
    )

    # Nothing to synthesize (e.g. the model produced no text).
    if not text:
        return

    # Stream TTS back; each chunk is already an audio frame FastRTC can send.
    yield from tts.stream_tts_sync(text)


# Two-way audio stream: ReplyOnPause invokes _voice_reply_fn whenever the
# user stops speaking, and the audio it yields is streamed back.
voice_stream = Stream(
    handler=ReplyOnPause(_voice_reply_fn),
    modality="audio",
    mode="send-receive",
)

# Expose FastRTC's endpoints (WebRTC + WebSocket) on the app under /rtc.
voice_stream.mount(app, path="/rtc")


# ----------------------------
# Scratch-friendly WebSocket API (text + function calling)
# ----------------------------

@app.get("/")
async def root():
    """Return a small JSON service descriptor pointing clients at /ws and /rtc."""
    usage_notes = [
        "Use /ws for Scratch JSON chat + function calling.",
        "Use /rtc for FastRTC voice chat endpoints (VAD/STT/TTS handled by FastRTC).",
    ]
    payload = dict(
        ok=True,
        service="salexai-api",
        ws="/ws",
        fastrtc="/rtc",
        notes=usage_notes,
    )
    return JSONResponse(payload)


@app.websocket("/ws")
async def ws_endpoint(ws: WebSocket):
    """
    Scratch-facing JSON WebSocket: text chat plus client-side function calling.

    Every frame must be a JSON object with a ``type`` field:

    * ``connect``: bind (or create) a session. Replies ``ready`` with its id.
    * ``add_function`` / ``remove_function`` / ``list_functions``: edit the
      session's function registry (``get_session(session_id).functions``).
    * ``function_result``: hand a tool result to ``deliver_function_result``.
    * ``send``: run one chat turn through ``gemini_chat_turn``. Replies
      ``assistant`` with the text, or ``error`` if the turn raised.

    Chat turns run as background tasks so this loop keeps reading frames while
    a turn is in flight. Tool calls are emitted to the client during the turn,
    and their results arrive as ``function_result`` frames on this same socket.
    Awaiting the turn inline would stop those frames from being read until the
    turn ended. NOTE(review): this assumes ``gemini_chat_turn`` waits for
    results delivered through ``deliver_function_result``; confirm in
    gemini_text.

    Turns on one connection still run one at a time. Malformed frames get an
    ``error`` reply instead of terminating the connection.
    """
    await ws.accept()

    session_id: Optional[str] = None
    # Frames are sent both from this loop and from chat-turn tasks; serialize them.
    send_lock = asyncio.Lock()
    # Keep chat turns on this connection sequential (one shared session history).
    turn_lock = asyncio.Lock()
    # Strong references so running turn tasks are not garbage-collected.
    turn_tasks: set[asyncio.Task] = set()

    async def emit(evt: dict):
        async with send_lock:
            await ws.send_text(json.dumps(evt))

    async def run_turn(turn_session_id: str, text: str):
        # One chat turn; failures are reported to the client, not raised.
        async with turn_lock:
            try:
                assistant_text = await gemini_chat_turn(
                    session_id=turn_session_id,
                    user_text=text,
                    emit_event=emit,  # this is where tool calls get emitted
                    model=os.getenv("GEMINI_TEXT_MODEL", "gemini-2.0-flash"),
                )
                await emit({"type": "assistant", "text": assistant_text})
            except Exception as e:
                try:
                    await emit({"type": "error", "message": f"Gemini error: {e}"})
                except Exception:
                    pass  # socket already gone; nobody left to tell

    try:
        while True:
            raw = await ws.receive_text()
            try:
                msg = json.loads(raw) if raw else {}
            except json.JSONDecodeError:
                await emit({"type": "error", "message": "Invalid JSON"})
                continue
            if not isinstance(msg, dict):
                await emit({"type": "error", "message": "Message must be a JSON object"})
                continue

            mtype = msg.get("type")

            if mtype == "connect":
                session_id = msg.get("session_id") or str(uuid.uuid4())
                get_session(session_id)  # ensure exists
                await emit({"type": "ready", "session_id": session_id})
                continue

            if not session_id:
                await emit({"type": "error", "message": "Not connected. Send {type:'connect'} first."})
                continue

            # -------- function registry --------

            if mtype == "add_function":
                name = str(msg.get("name") or "").strip()
                schema = msg.get("schema") or {}
                if not name:
                    await emit({"type": "error", "message": "add_function missing name"})
                    continue
                s = get_session(session_id)
                s.functions[name] = schema
                await emit({"type": "function_added", "name": name})
                continue

            if mtype == "remove_function":
                name = str(msg.get("name") or "").strip()
                s = get_session(session_id)
                if name in s.functions:
                    s.functions.pop(name, None)
                    await emit({"type": "function_removed", "name": name})
                else:
                    await emit({"type": "warning", "message": f"Function not found: {name}"})
                continue

            if mtype == "list_functions":
                s = get_session(session_id)
                await emit({"type": "functions", "items": list(s.functions.keys())})
                continue

            # Client returns tool results (possibly while a turn is in flight).
            if mtype == "function_result":
                call_id = msg.get("call_id")
                result = msg.get("result")
                if not call_id:
                    await emit({"type": "error", "message": "function_result missing call_id"})
                    continue
                ok = deliver_function_result(session_id, call_id, result)
                if not ok:
                    await emit({"type": "warning", "message": f"No pending call_id: {call_id}"})
                else:
                    await emit({"type": "function_result_ack", "call_id": call_id})
                continue

            # -------- chat --------

            if mtype == "send":
                text = str(msg.get("text") or "")
                if not text.strip():
                    await emit({"type": "error", "message": "Empty text"})
                    continue
                # Run the turn in the background so function_result frames
                # emitted in response to its tool calls can still be received.
                task = asyncio.create_task(run_turn(session_id, text))
                turn_tasks.add(task)
                task.add_done_callback(turn_tasks.discard)
                continue

            await emit({"type": "error", "message": f"Unknown type: {mtype}"})

    except WebSocketDisconnect:
        return
    except Exception as e:
        try:
            await emit({"type": "error", "message": f"WS crashed: {e}"})
        except Exception:
            pass  # best effort: the socket may already be unusable
    finally:
        # Turns that outlive the socket could only fail when they try to emit.
        for task in list(turn_tasks):
            task.cancel()