Spaces:
Sleeping
Sleeping
File size: 7,037 Bytes
226ff5d 6eb38ac 226ff5d 38a5904 4be6d55 aebb7d4 e75bac4 4be6d55 e75bac4 38a5904 4be6d55 e75bac4 226ff5d 6eb38ac 226ff5d e75bac4 226ff5d 4be6d55 b35a210 4be6d55 226ff5d 986403e 226ff5d e75bac4 986403e 226ff5d e75bac4 38a5904 226ff5d e75bac4 226ff5d aebb7d4 38a5904 aebb7d4 38a5904 4be6d55 38a5904 aebb7d4 38a5904 4be6d55 38a5904 aebb7d4 c2e783d e75bac4 aebb7d4 e75bac4 aebb7d4 e75bac4 aebb7d4 226ff5d e75bac4 38a5904 e75bac4 226ff5d 38a5904 986403e 38a5904 226ff5d 4be6d55 226ff5d 4be6d55 226ff5d 38a5904 e75bac4 38a5904 e75bac4 226ff5d 986403e 226ff5d 986403e 226ff5d 6eb38ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 | import json
import logging
import logging.config
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, Form, UploadFile, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from src.pipeline import VoicePipeline
from typing import Literal
from src.config import (
DEEPGRAM_API_KEY, CARTESIA_API_KEY, CARTESIA_VOICE_ID,
SAMPLE_RATE, GOOGLE_API_KEY, GOOGLE_PROJECT_ID, STT_PROVIDER, TTS_PROVIDER, WAKE_WORD_ENABLED,
)
from src.stt.deepgram_rest import transcribe_audio as deepgram_transcribe
from src.stt.gemini_stt import transcribe_audio as gemini_stt_transcribe
from src.stt.chirp3_client import transcribe_audio as chirp3_transcribe
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize
from src.tts.gemini_client import synthesize_stream as gemini_synthesize, GEMINI_SAMPLE_RATE
# Log line layout shared by every handler below.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"

# dictConfig schema: one stream handler attached to the root logger at INFO.
# The same dict is later handed to uvicorn via ``log_config``.
LOG_CONFIG = dict(
    version=1,
    # Leave loggers that already exist (e.g. created by modules imported
    # above) enabled instead of silencing them.
    disable_existing_loggers=False,
    formatters=dict(default=dict(format=_LOG_FORMAT)),
    handlers=dict(default={"class": "logging.StreamHandler", "formatter": "default"}),
    root=dict(level="INFO", handlers=["default"]),
)
logging.config.dictConfig(LOG_CONFIG)
logger = logging.getLogger(__name__)

# Reported by both "/" and "/health" and passed to FastAPI as the app version.
VERSION = "1.2.0"

app = FastAPI(version=VERSION, title="Voice Agent Service")

# Fully open CORS policy: any origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive (Starlette reflects the caller's Origin on requests that carry
# cookies) — confirm this is intended before exposing credentialed endpoints.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/")
async def root() -> JSONResponse:
    """Landing endpoint: greet the caller and report the service version.

    Deliberately not named ``health``: FastAPI uses the endpoint function's
    name as the route name (for ``app.url_path_for`` and the OpenAPI
    summary), and ``health`` belongs to the ``/health`` handler.
    """
    body: dict = {
        "message": "Welcome to the Voice Agent Service! Please use the /health endpoint to check service status.",
        "version": VERSION,
    }
    return JSONResponse(status_code=200, content=body)
@app.get("/health")
async def health() -> JSONResponse:
    """Report which speech providers are configured.

    Returns HTTP 200 with ``status: "ok"`` when the Deepgram STT key and the
    Cartesia TTS key + voice id are all set; otherwise HTTP 503 with
    ``status: "degraded"`` and an explanatory ``message``. Readiness flags for
    the Google-backed providers are reported but do not affect the overall
    status.
    """
    stt_ready = bool(DEEPGRAM_API_KEY)
    tts_ready = bool(CARTESIA_API_KEY and CARTESIA_VOICE_ID)
    gemini_ready = bool(GOOGLE_API_KEY)
    # Chirp 3 (the /stt default provider) needs a project id, not an API key.
    chirp3_ready = bool(GOOGLE_PROJECT_ID)
    all_ready = stt_ready and tts_ready
    body: dict = {
        "status": "ok" if all_ready else "degraded",
        "version": VERSION,
        "stt_ready": stt_ready,
        "tts_ready": tts_ready,
        "gemini_tts_ready": gemini_ready,
        "gemini_stt_ready": gemini_ready,
        "chirp3_stt_ready": chirp3_ready,
    }
    if not all_ready:
        body["message"] = "One or more required configurations are missing. Check your .env file."
    return JSONResponse(status_code=200 if all_ready else 503, content=body)
# Synthesis backends selectable through POST /tts.
_TTSProvider = Literal["cartesia", "gemini"]


class TTSRequest(BaseModel):
    """JSON body accepted by ``POST /tts``."""

    # Text to synthesize (the /tts handler rejects blank values).
    text: str
    # Which backend renders the audio; Gemini unless stated otherwise.
    provider: _TTSProvider = "gemini"
@app.post("/stt")
async def speech_to_text(
    audio: UploadFile = File(...),
    provider: str = Form(default="chirp3"),
) -> JSONResponse:
    """Transcribe an uploaded audio file with the selected STT provider.

    Args:
        audio: Uploaded audio. Its declared content type is forwarded to the
            provider, falling back to ``audio/wav`` when the client sent none.
        provider: ``"chirp3"`` (default) or ``"gemini"``; any other value is
            routed to Deepgram (kept for backward compatibility).

    Returns:
        JSON response whose body is exactly what the chosen provider's
        ``transcribe_audio`` returned.

    Raises:
        HTTPException: 400 if the upload is empty; 503 if the selected
            provider's credentials are not configured.
    """
    data = await audio.read()
    if not data:
        raise HTTPException(status_code=400, detail="Audio file is empty.")
    mimetype = audio.content_type or "audio/wav"

    if provider == "chirp3":
        if not GOOGLE_PROJECT_ID:
            raise HTTPException(status_code=503, detail="Chirp3 STT not configured: missing GOOGLE_PROJECT_ID.")
        result = await chirp3_transcribe(data, mimetype=mimetype)
    elif provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini STT not configured: missing GOOGLE_API_KEY.")
        result = await gemini_stt_transcribe(data, mimetype=mimetype)
    else:
        if provider != "deepgram":
            # Make silent fall-through (e.g. a typo in the provider name) visible.
            logger.warning("Unknown STT provider %r; falling back to Deepgram.", provider)
        # Same credential check the other branches perform, so a missing key
        # yields a clear 503 instead of an opaque provider failure.
        if not DEEPGRAM_API_KEY:
            raise HTTPException(status_code=503, detail="Deepgram STT not configured: missing DEEPGRAM_API_KEY.")
        result = await deepgram_transcribe(data, mimetype=mimetype)
    return JSONResponse(content=result)
@app.post("/tts")
async def text_to_speech(req: TTSRequest) -> StreamingResponse:
    """Stream synthesized speech for ``req.text`` as raw PCM.

    The body is produced by the selected provider's ``synthesize_stream``.
    The ``X-Sample-Rate``, ``X-Encoding`` (``pcm_s16le``) and ``X-Channels``
    (mono) headers tell the client how to interpret the bytes.

    Raises:
        HTTPException: 400 if the text is blank; 503 if the selected
            provider's credentials are not configured.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty.")
    # NOTE(review): this logs the full request text at INFO; consider DEBUG
    # if user text may be sensitive.
    logger.info("TTS request: provider=%s, text_len=%d, text=%r", req.provider, len(req.text), req.text)
    if req.provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini TTS not configured.")
        stream = gemini_synthesize(req.text)
        sample_rate = GEMINI_SAMPLE_RATE
    else:
        # Cartesia needs both an API key and a voice id (the same criteria
        # /health uses for tts_ready); fail fast rather than mid-stream.
        if not (CARTESIA_API_KEY and CARTESIA_VOICE_ID):
            raise HTTPException(status_code=503, detail="Cartesia TTS not configured.")
        stream = cartesia_synthesize(req.text)
        sample_rate = SAMPLE_RATE
    return StreamingResponse(
        stream,
        media_type="audio/pcm",
        headers={
            "X-Sample-Rate": str(sample_rate),
            "X-Encoding": "pcm_s16le",
            "X-Channels": "1",
        },
    )
@app.websocket("/ws/voice")
async def voice_ws(
    ws: WebSocket,
    stt_provider: str = Query(default=STT_PROVIDER),
    tts_provider: str = Query(default=TTS_PROVIDER),
    wake_word_enabled: bool = Query(default=WAKE_WORD_ENABLED),
) -> None:
    """Run a full-duplex voice session over a WebSocket.

    Protocol, as implemented below:

    * On connect the server sends a ``tts_config`` JSON event describing the
      PCM stream it will emit (sample rate, ``pcm_s16le``, 1 channel).
    * Binary frames from the client are fed to the ``VoicePipeline`` as audio.
    * Text frames are JSON objects with an ``action`` key: ``stop`` ends the
      session, ``ping`` is answered with a ``pong`` event, ``interrupt`` calls
      ``pipeline.interrupt()``, and ``speak`` passes the stripped ``text``
      field to ``pipeline.speak()``. Malformed or unknown messages are ignored.

    The pipeline is always stopped when the session ends, however it ends.

    Args:
        ws: The client connection.
        stt_provider: STT backend name handed to the pipeline (query param).
        tts_provider: TTS backend name handed to the pipeline (query param);
            ``"gemini"`` announces ``GEMINI_SAMPLE_RATE``, anything else
            ``SAMPLE_RATE``.
        wake_word_enabled: Forwarded to the pipeline (query param).
    """
    await ws.accept()
    logger.info(
        "Client connected: %s (stt=%s, tts=%s)",
        ws.client, stt_provider, tts_provider,
    )

    # Output callbacks handed to the pipeline. A send to a client that has
    # gone away is dropped rather than raised; the receive loop below notices
    # the disconnect on its own.
    async def send_audio(chunk: bytes) -> None:
        try:
            await ws.send_bytes(chunk)
        except WebSocketDisconnect:
            pass

    async def send_event(event: dict) -> None:
        try:
            await ws.send_text(json.dumps(event))
        except WebSocketDisconnect:
            pass

    tts_sample_rate = GEMINI_SAMPLE_RATE if tts_provider == "gemini" else SAMPLE_RATE
    await send_event({
        "event": "tts_config",
        "tts_provider": tts_provider,
        "stt_provider": stt_provider,
        "sample_rate": tts_sample_rate,
        "encoding": "pcm_s16le",
        "channels": 1,
    })

    pipeline = VoicePipeline(
        send_audio=send_audio,
        send_event=send_event,
        stt_provider=stt_provider,
        tts_provider=tts_provider,
        wake_word_enabled=wake_word_enabled,
    )
    pipeline.start()
    try:
        while True:
            message = await ws.receive()
            # Starlette's receive() *returns* a disconnect message rather than
            # raising; calling receive() again afterwards raises RuntimeError.
            if message.get("type") == "websocket.disconnect":
                logger.info("Client disconnected: %s", ws.client)
                break

            audio = message.get("bytes")
            if audio:
                pipeline.feed_audio(audio)
                continue

            raw = message.get("text")
            if not raw:
                continue
            try:
                msg = json.loads(raw)
            except json.JSONDecodeError:
                continue
            # Valid JSON that is not an object (a list, number, ...) carries no action.
            if not isinstance(msg, dict):
                continue

            action = msg.get("action")
            if action == "stop":
                break
            if action == "ping":
                await send_event({"event": "pong"})
            elif action == "interrupt":
                await pipeline.interrupt()
            elif action == "speak":
                text = msg.get("text")
                # Ignore missing, non-string or blank text instead of crashing.
                if isinstance(text, str) and text.strip():
                    await pipeline.speak(text.strip())
    except WebSocketDisconnect:
        logger.info("Client disconnected: %s", ws.client)
    finally:
        try:
            await pipeline.stop_async()
        except Exception:
            logger.exception("Error during pipeline stop")
if __name__ == "__main__":
    import os

    # Pass the app object rather than the "main:app" import string. With
    # reload disabled, uvicorn does not need a string, and importing "main"
    # would execute this module's top-level setup a second time (it is
    # already running as __main__); it would also break if the file were
    # renamed. The port can be overridden via $PORT and defaults to 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "7860")),
        reload=False,
        log_config=LOG_CONFIG,
    )
|