File size: 7,037 Bytes
226ff5d
 
6eb38ac
226ff5d
38a5904
4be6d55
aebb7d4
 
e75bac4
 
 
 
4be6d55
e75bac4
38a5904
 
4be6d55
e75bac4
 
226ff5d
6eb38ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226ff5d
 
e75bac4
226ff5d
 
 
4be6d55
 
 
b35a210
4be6d55
 
 
 
226ff5d
986403e
 
 
 
 
 
 
 
226ff5d
 
 
 
e75bac4
986403e
226ff5d
 
 
 
 
 
e75bac4
38a5904
226ff5d
 
e75bac4
226ff5d
 
 
 
aebb7d4
 
38a5904
aebb7d4
 
 
38a5904
 
4be6d55
38a5904
aebb7d4
 
 
 
38a5904
4be6d55
 
 
 
 
38a5904
 
 
 
 
 
aebb7d4
 
 
 
 
 
 
c2e783d
 
e75bac4
 
 
 
 
 
 
 
 
 
aebb7d4
e75bac4
aebb7d4
 
e75bac4
aebb7d4
 
 
 
 
 
226ff5d
e75bac4
 
38a5904
 
 
e75bac4
226ff5d
38a5904
986403e
 
38a5904
226ff5d
 
4be6d55
 
 
 
226ff5d
 
4be6d55
 
 
 
226ff5d
38a5904
 
 
 
 
 
 
 
 
 
e75bac4
 
 
38a5904
 
 
e75bac4
226ff5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
986403e
 
 
 
226ff5d
 
 
 
 
986403e
 
 
 
226ff5d
 
 
6eb38ac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import json
import logging
import logging.config
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, Form, UploadFile, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from src.pipeline import VoicePipeline
from typing import Literal
from src.config import (
    DEEPGRAM_API_KEY, CARTESIA_API_KEY, CARTESIA_VOICE_ID,
    SAMPLE_RATE, GOOGLE_API_KEY, GOOGLE_PROJECT_ID, STT_PROVIDER, TTS_PROVIDER, WAKE_WORD_ENABLED,
)
from src.stt.deepgram_rest import transcribe_audio as deepgram_transcribe
from src.stt.gemini_stt import transcribe_audio as gemini_stt_transcribe
from src.stt.chirp3_client import transcribe_audio as chirp3_transcribe
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize
from src.tts.gemini_client import synthesize_stream as gemini_synthesize, GEMINI_SAMPLE_RATE

# Logging configuration in the ``logging.config.dictConfig`` schema (version 1).
# It is applied right below at import time and is also handed to uvicorn in the
# ``__main__`` entry point at the bottom of this file.
LOG_CONFIG = {
    "version": 1,
    # Do not disable loggers that were created before this config is applied
    # (e.g. by the modules imported above).
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            # timestamp [LEVEL] logger-name: message
            "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        },
    },
    "handlers": {
        "default": {
            # StreamHandler with no explicit stream argument.
            "class": "logging.StreamHandler",
            "formatter": "default",
        },
    },
    # Root logger: INFO and above go to the single "default" handler.
    "root": {
        "level": "INFO",
        "handlers": ["default"],
    },
}
# Applied as an import-time side effect so every logger in the process uses it.
logging.config.dictConfig(LOG_CONFIG)
# Module-level logger used by the endpoints below.
logger = logging.getLogger(__name__)

# Service version: passed to the FastAPI app metadata and echoed by the "/"
# and "/health" responses.
VERSION = "1.2.0"

# The ASGI application; all routes in this module are registered on it.
app = FastAPI(title="Voice Agent Service", version=VERSION)

# Fully permissive CORS: every origin, method and header is allowed.
# NOTE(review): the FastAPI CORS documentation states that wildcard values
# should not be combined with allow_credentials=True. Confirm whether
# credentialed cross-origin requests are actually needed; if they are, list
# the allowed origins explicitly instead of "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root() -> JSONResponse:
    """Return a static welcome payload for the service root.

    This handler used to be named ``health``, the same name as the ``/health``
    handler defined right after it, so the module-level name was silently
    rebound. Routes are registered at decoration time, so the path and the
    response body are unaffected by the rename.

    Returns:
        JSONResponse with status 200 containing a welcome ``message`` and the
        service ``version``.
    """
    body: dict = {
        "message": "Welcome to the Voice Agent Service! Please use the /health endpoint to check service status.",
        "version": VERSION,
    }
    return JSONResponse(status_code=200, content=body)

@app.get("/health")
async def health() -> JSONResponse:
    """Report which provider credentials are configured.

    Readiness is derived purely from whether the corresponding configuration
    values are non-empty; no provider is actually contacted.

    Returns:
        JSONResponse with status 200 and ``"status": "ok"`` when both the
        Deepgram key and the Cartesia key/voice id are set, otherwise status
        503 with ``"status": "degraded"`` and an explanatory ``message``.
        The per-provider flags are always included.
    """
    stt_ready = bool(DEEPGRAM_API_KEY)
    tts_ready = bool(CARTESIA_API_KEY and CARTESIA_VOICE_ID)
    gemini_ready = bool(GOOGLE_API_KEY)
    # Chirp3 is the default provider of the /stt endpoint, which requires
    # GOOGLE_PROJECT_ID; expose its readiness alongside the other providers.
    chirp3_ready = bool(GOOGLE_PROJECT_ID)
    # NOTE(review): overall status only considers Deepgram + Cartesia, even
    # though other providers can be selected elsewhere in this module. Left
    # unchanged so existing health probes keep their current semantics.
    all_ready = stt_ready and tts_ready

    body: dict = {
        "status": "ok" if all_ready else "degraded",
        "version": VERSION,
        "stt_ready": stt_ready,
        "tts_ready": tts_ready,
        "gemini_tts_ready": gemini_ready,
        "gemini_stt_ready": gemini_ready,
        "chirp3_stt_ready": chirp3_ready,
    }
    if not all_ready:
        body["message"] = "One or more required configurations are missing. Check your .env file."

    return JSONResponse(status_code=200 if all_ready else 503, content=body)


class TTSRequest(BaseModel):
    """JSON request body accepted by the ``POST /tts`` endpoint."""

    # Text to synthesize; the endpoint rejects values that are blank after
    # stripping whitespace.
    text: str
    # TTS backend to use; restricted to these two values, defaulting to "gemini".
    provider: Literal["cartesia", "gemini"] = "gemini"


@app.post("/stt")
async def speech_to_text(
    audio: UploadFile = File(...),
    provider: str = Form(default="chirp3"),
) -> JSONResponse:
    """Transcribe an uploaded audio file with the selected STT provider.

    Args:
        audio: Uploaded audio file (multipart form field). Its declared
            content type is forwarded to the provider; ``"audio/wav"`` is
            used when the upload does not declare one.
        provider: ``"chirp3"`` (default) or ``"gemini"``; any other value is
            routed to Deepgram.

    Returns:
        JSONResponse wrapping whatever the provider's transcribe function
        returned.

    Raises:
        HTTPException: 400 when the upload is empty; 503 when the credentials
            required by the selected provider are not configured.
    """
    data = await audio.read()
    if not data:
        raise HTTPException(status_code=400, detail="Audio file is empty.")
    mimetype = audio.content_type or "audio/wav"

    if provider == "chirp3":
        if not GOOGLE_PROJECT_ID:
            raise HTTPException(status_code=503, detail="Chirp3 STT not configured: missing GOOGLE_PROJECT_ID.")
        result = await chirp3_transcribe(data, mimetype=mimetype)
    elif provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini STT not configured: missing GOOGLE_API_KEY.")
        result = await gemini_stt_transcribe(data, mimetype=mimetype)
    else:
        # Fallback provider. Guard it like the other branches so a missing key
        # yields a clear 503 instead of an opaque failure inside the client.
        if not DEEPGRAM_API_KEY:
            raise HTTPException(status_code=503, detail="Deepgram STT not configured: missing DEEPGRAM_API_KEY.")
        result = await deepgram_transcribe(data, mimetype=mimetype)

    return JSONResponse(content=result)


@app.post("/tts")
async def text_to_speech(req: TTSRequest) -> StreamingResponse:
    """Synthesize speech for ``req.text`` and stream the audio back.

    Args:
        req: Request body carrying the text and the provider
            (``"gemini"`` or ``"cartesia"``).

    Returns:
        StreamingResponse over the provider's synthesis stream, sent with
        media type ``audio/pcm`` and ``X-Sample-Rate`` / ``X-Encoding`` /
        ``X-Channels`` headers describing the audio format.

    Raises:
        HTTPException: 400 when the text is blank; 503 when the selected
            provider is not configured.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty.")

    # NOTE(review): the full request text is written to the log at INFO level;
    # confirm this is acceptable for the data this service handles.
    logger.info("TTS request: provider=%s, text_len=%d, text=%r", req.provider, len(req.text), req.text)

    if req.provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini TTS not configured.")
        stream = gemini_synthesize(req.text)
        sample_rate = GEMINI_SAMPLE_RATE
    else:
        # Same readiness condition the /health endpoint uses for Cartesia;
        # fail with a clear 503 instead of erroring inside the stream.
        if not (CARTESIA_API_KEY and CARTESIA_VOICE_ID):
            raise HTTPException(status_code=503, detail="Cartesia TTS not configured.")
        stream = cartesia_synthesize(req.text)
        sample_rate = SAMPLE_RATE

    return StreamingResponse(
        stream,
        media_type="audio/pcm",
        headers={
            # Each provider has its own sample rate; tell the client which one applies.
            "X-Sample-Rate": str(sample_rate),
            "X-Encoding": "pcm_s16le",
            "X-Channels": "1",
        },
    )


@app.websocket("/ws/voice")
async def voice_ws(
    ws: WebSocket,
    stt_provider: str = Query(default=STT_PROVIDER),
    tts_provider: str = Query(default=TTS_PROVIDER),
    wake_word_enabled: bool = Query(default=WAKE_WORD_ENABLED),
) -> None:
    """Run a voice session over a WebSocket.

    After accepting the connection a ``tts_config`` event describing the
    outgoing audio format is sent, then a ``VoicePipeline`` is started.
    Incoming frames are handled as follows:

    * binary frame: forwarded to ``pipeline.feed_audio``;
    * text frame: parsed as a JSON object whose ``action`` is one of
      ``"stop"`` (end the session), ``"ping"`` (answered with a ``pong``
      event), ``"interrupt"`` (calls ``pipeline.interrupt``) or ``"speak"``
      (calls ``pipeline.speak`` with the non-blank ``text`` field).

    Malformed or unrecognized text frames are ignored. The pipeline is always
    stopped when the session ends, however it ends.

    Args:
        ws: The WebSocket connection.
        stt_provider: STT provider name passed to the pipeline.
        tts_provider: TTS provider name passed to the pipeline; ``"gemini"``
            selects the Gemini sample rate for the ``tts_config`` event.
        wake_word_enabled: Passed through to the pipeline.
    """
    await ws.accept()
    logger.info(
        "Client connected: %s (stt=%s, tts=%s)",
        ws.client, stt_provider, tts_provider,
    )

    async def send_audio(chunk: bytes) -> None:
        # Best-effort send: the pipeline may still be emitting after the peer
        # has gone. Starlette raises RuntimeError (not WebSocketDisconnect)
        # when sending on an already-closed socket, so tolerate both.
        try:
            await ws.send_bytes(chunk)
        except (WebSocketDisconnect, RuntimeError):
            logger.debug("Audio chunk dropped: client no longer connected")

    async def send_event(event: dict) -> None:
        # Best-effort send, same rationale as send_audio.
        try:
            await ws.send_text(json.dumps(event))
        except (WebSocketDisconnect, RuntimeError):
            logger.debug("Event dropped: client no longer connected")

    tts_sample_rate = GEMINI_SAMPLE_RATE if tts_provider == "gemini" else SAMPLE_RATE
    await send_event({
        "event": "tts_config",
        "tts_provider": tts_provider,
        "stt_provider": stt_provider,
        "sample_rate": tts_sample_rate,
        "encoding": "pcm_s16le",
        "channels": 1,
    })

    pipeline = VoicePipeline(
        send_audio=send_audio,
        send_event=send_event,
        stt_provider=stt_provider,
        tts_provider=tts_provider,
        wake_word_enabled=wake_word_enabled,
    )
    pipeline.start()

    try:
        while True:
            data = await ws.receive()

            # The low-level receive() hands back the disconnect message
            # instead of raising WebSocketDisconnect; calling receive() again
            # afterwards raises RuntimeError, so leave the loop here.
            if data.get("type") == "websocket.disconnect":
                logger.info("Client disconnected: %s", ws.client)
                break

            if data.get("bytes"):
                pipeline.feed_audio(data["bytes"])
                continue

            raw_text = data.get("text")
            if not raw_text:
                continue

            try:
                msg = json.loads(raw_text)
            except json.JSONDecodeError:
                # Ignore text frames that are not valid JSON.
                continue
            if not isinstance(msg, dict):
                # Valid JSON but not an object (e.g. a list or a number):
                # there is no "action" to dispatch on.
                continue

            action = msg.get("action")
            if action == "stop":
                break
            elif action == "ping":
                await ws.send_text(json.dumps({"event": "pong"}))
            elif action == "interrupt":
                await pipeline.interrupt()
            elif action == "speak":
                text = msg.get("text", "")
                # Guard against non-string payloads such as null or numbers.
                if isinstance(text, str) and text.strip():
                    await pipeline.speak(text.strip())
    except WebSocketDisconnect:
        logger.info("Client disconnected: %s", ws.client)
    finally:
        # Always release pipeline resources; a failure here must not mask the
        # reason the session ended.
        try:
            await pipeline.stop_async()
        except Exception:
            logger.exception("Error during pipeline stop")


if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn, reusing the
    # module's logging configuration for uvicorn's own loggers.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
        log_config=LOG_CONFIG,
    )