ishaq101's picture
fixing error stt and tts, empty chunk audio
c2e783d
import json
import logging
import logging.config
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, File, Form, UploadFile, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from src.pipeline import VoicePipeline
from typing import Literal
from src.config import (
DEEPGRAM_API_KEY, CARTESIA_API_KEY, CARTESIA_VOICE_ID,
SAMPLE_RATE, GOOGLE_API_KEY, GOOGLE_PROJECT_ID, STT_PROVIDER, TTS_PROVIDER, WAKE_WORD_ENABLED,
)
from src.stt.deepgram_rest import transcribe_audio as deepgram_transcribe
from src.stt.gemini_stt import transcribe_audio as gemini_stt_transcribe
from src.stt.chirp3_client import transcribe_audio as chirp3_transcribe
from src.tts.cartesia_client import synthesize_stream as cartesia_synthesize
from src.tts.gemini_client import synthesize_stream as gemini_synthesize, GEMINI_SAMPLE_RATE
# Single log line format shared by the app and (via uvicorn.run) the server.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"

# logging.config.dictConfig schema: one stream handler on the root logger at
# INFO. Existing loggers are left enabled so library loggers keep emitting.
LOG_CONFIG = dict(
    version=1,
    disable_existing_loggers=False,
    formatters={"default": {"format": _LOG_FORMAT}},
    handlers={"default": {"class": "logging.StreamHandler", "formatter": "default"}},
    root={"level": "INFO", "handlers": ["default"]},
)

logging.config.dictConfig(LOG_CONFIG)
logger = logging.getLogger(__name__)
# Service version, reported in the OpenAPI metadata and in the JSON bodies
# of the root and /health endpoints.
VERSION = "1.2.0"

app = FastAPI(title="Voice Agent Service", version=VERSION)

# CORS is fully open: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive; confirm this is intended before exposing the service publicly.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
@app.get("/")
async def root() -> JSONResponse:
    """Return a welcome message and the service version.

    Named ``root`` rather than ``health`` so the ``/health`` handler defined
    later in this module does not shadow it at module level, and so the two
    routes do not share the same route name.
    """
    body: dict = {
        "message": "Welcome to the Voice Agent Service! Please use the /health endpoint to check service status.",
        "version": VERSION,
    }
    return JSONResponse(status_code=200, content=body)
@app.get("/health")
async def health() -> JSONResponse:
    """Report which speech providers have the configuration they need.

    Deepgram STT and Cartesia TTS are treated as required: if either is
    missing, ``status`` is ``"degraded"`` and the response code is 503.
    The Google-backed providers (Gemini STT/TTS, Chirp3 STT) are reported
    for information only and do not affect the status code.
    """
    stt_ready = bool(DEEPGRAM_API_KEY)
    tts_ready = bool(CARTESIA_API_KEY and CARTESIA_VOICE_ID)
    gemini_ready = bool(GOOGLE_API_KEY)
    # Chirp3 is the default provider of POST /stt and is gated on the
    # project id there, so expose its readiness too.
    chirp3_ready = bool(GOOGLE_PROJECT_ID)
    all_ready = stt_ready and tts_ready

    body: dict = {
        "status": "ok" if all_ready else "degraded",
        "version": VERSION,
        "stt_ready": stt_ready,
        "tts_ready": tts_ready,
        "gemini_tts_ready": gemini_ready,
        "gemini_stt_ready": gemini_ready,
        "chirp3_stt_ready": chirp3_ready,
    }
    if not all_ready:
        body["message"] = "One or more required configurations are missing. Check your .env file."
    return JSONResponse(status_code=200 if all_ready else 503, content=body)
class TTSRequest(BaseModel):
    """JSON request body for ``POST /tts``."""

    # Text to synthesize; the /tts handler rejects blank text with HTTP 400.
    text: str
    # Which TTS backend to use; any other value fails request validation.
    provider: Literal["cartesia", "gemini"] = "gemini"
@app.post("/stt")
async def speech_to_text(
    audio: UploadFile = File(...),
    provider: str = Form(default="chirp3"),
) -> JSONResponse:
    """Transcribe an uploaded audio file with the selected STT provider.

    Args:
        audio: The uploaded audio. Its declared content type is forwarded to
            the provider, defaulting to ``audio/wav`` when absent.
        provider: ``"chirp3"`` (default), ``"gemini"`` or ``"deepgram"``.
            Any other value falls back to Deepgram, with a warning logged.

    Returns:
        The provider's transcription result as JSON.

    Raises:
        HTTPException: 400 if the upload is empty; 503 if the chosen
            provider is missing its configuration.
    """
    data = await audio.read()
    if not data:
        raise HTTPException(status_code=400, detail="Audio file is empty.")
    mimetype = audio.content_type or "audio/wav"

    if provider == "chirp3":
        if not GOOGLE_PROJECT_ID:
            raise HTTPException(status_code=503, detail="Chirp3 STT not configured: missing GOOGLE_PROJECT_ID.")
        result = await chirp3_transcribe(data, mimetype=mimetype)
    elif provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini STT not configured: missing GOOGLE_API_KEY.")
        result = await gemini_stt_transcribe(data, mimetype=mimetype)
    else:
        if provider != "deepgram":
            # Kept as a fallback for backward compatibility, but make typos visible.
            logger.warning("Unknown STT provider %r; falling back to deepgram.", provider)
        # Fail fast like the other branches instead of letting the client call fail.
        if not DEEPGRAM_API_KEY:
            raise HTTPException(status_code=503, detail="Deepgram STT not configured: missing DEEPGRAM_API_KEY.")
        result = await deepgram_transcribe(data, mimetype=mimetype)
    return JSONResponse(content=result)
@app.post("/tts")
async def text_to_speech(req: TTSRequest) -> StreamingResponse:
    """Stream synthesized speech for ``req.text`` as raw PCM.

    The audio format is described by response headers: ``X-Sample-Rate``
    (provider-dependent), ``X-Encoding: pcm_s16le`` and ``X-Channels: 1``.

    Raises:
        HTTPException: 400 if the text is blank; 503 if the selected
            provider is not configured. Configuration is checked before
            streaming starts, because once the 200 response has begun a
            failure can only truncate the audio stream.
    """
    if not req.text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty.")
    logger.info("TTS request: provider=%s, text_len=%d", req.provider, len(req.text))
    # The user's text may be sensitive; only emit it when debugging.
    logger.debug("TTS request text: %r", req.text)

    if req.provider == "gemini":
        if not GOOGLE_API_KEY:
            raise HTTPException(status_code=503, detail="Gemini TTS not configured.")
        stream = gemini_synthesize(req.text)
        sample_rate = GEMINI_SAMPLE_RATE
    else:
        # Same requirement the /health endpoint uses for "tts_ready".
        if not (CARTESIA_API_KEY and CARTESIA_VOICE_ID):
            raise HTTPException(status_code=503, detail="Cartesia TTS not configured.")
        stream = cartesia_synthesize(req.text)
        sample_rate = SAMPLE_RATE

    return StreamingResponse(
        stream,
        media_type="audio/pcm",
        headers={
            "X-Sample-Rate": str(sample_rate),
            "X-Encoding": "pcm_s16le",
            "X-Channels": "1",
        },
    )
async def _handle_control_message(raw: str, pipeline: VoicePipeline, send_event) -> bool:
    """Dispatch one JSON text frame received on ``/ws/voice``.

    Args:
        raw: Text payload of the frame.
        pipeline: The session pipeline that ``interrupt``/``speak`` act on.
        send_event: Coroutine function used to reply with a JSON event.

    Returns:
        False when the client sent ``{"action": "stop"}`` and the session
        should end; True otherwise. Frames that are not JSON objects and
        unknown actions are ignored.
    """
    try:
        msg = json.loads(raw)
    except json.JSONDecodeError:
        logger.debug("Ignoring non-JSON text frame")
        return True
    if not isinstance(msg, dict):
        # Valid JSON that is not an object (list, string, number) has no .get().
        logger.debug("Ignoring non-object JSON text frame")
        return True

    action = msg.get("action")
    if action == "stop":
        return False
    if action == "ping":
        await send_event({"event": "pong"})
    elif action == "interrupt":
        await pipeline.interrupt()
    elif action == "speak":
        text = msg.get("text", "")
        # Ignore missing, blank or non-string text rather than crashing the session.
        if isinstance(text, str) and text.strip():
            await pipeline.speak(text.strip())
    return True


@app.websocket("/ws/voice")
async def voice_ws(
    ws: WebSocket,
    stt_provider: str = Query(default=STT_PROVIDER),
    tts_provider: str = Query(default=TTS_PROVIDER),
    wake_word_enabled: bool = Query(default=WAKE_WORD_ENABLED),
) -> None:
    """Run one full-duplex voice session over a WebSocket.

    On connect the server sends a ``tts_config`` event describing the PCM
    format of the audio frames it will send back. The client then sends:

    * binary frames: audio, fed to the pipeline as-is;
    * text frames: JSON objects with an ``action`` of ``stop`` (end the
      session), ``ping`` (answered with ``{"event": "pong"}``),
      ``interrupt``, or ``speak`` (with a ``text`` field).

    The pipeline is always stopped when the session ends. If the session
    ends because the client asked to stop, the socket is closed explicitly.
    """
    await ws.accept()
    logger.info(
        "Client connected: %s (stt=%s, tts=%s)",
        ws.client, stt_provider, tts_provider,
    )

    # Sends to a client that has already disconnected are dropped rather
    # than raised into the caller.
    async def send_audio(chunk: bytes) -> None:
        try:
            await ws.send_bytes(chunk)
        except WebSocketDisconnect:
            pass

    async def send_event(event: dict) -> None:
        try:
            await ws.send_text(json.dumps(event))
        except WebSocketDisconnect:
            pass

    tts_sample_rate = GEMINI_SAMPLE_RATE if tts_provider == "gemini" else SAMPLE_RATE
    await send_event({
        "event": "tts_config",
        "tts_provider": tts_provider,
        "stt_provider": stt_provider,
        "sample_rate": tts_sample_rate,
        "encoding": "pcm_s16le",
        "channels": 1,
    })

    pipeline = VoicePipeline(
        send_audio=send_audio,
        send_event=send_event,
        stt_provider=stt_provider,
        tts_provider=tts_provider,
        wake_word_enabled=wake_word_enabled,
    )
    pipeline.start()

    client_gone = False
    try:
        while True:
            message = await ws.receive()
            # WebSocket.receive() does not raise on disconnect: it returns a
            # "websocket.disconnect" message, and calling receive() again
            # afterwards raises RuntimeError. Detect it explicitly.
            if message["type"] == "websocket.disconnect":
                client_gone = True
                logger.info("Client disconnected: %s", ws.client)
                break
            if message.get("bytes"):
                pipeline.feed_audio(message["bytes"])
            elif message.get("text"):
                if not await _handle_control_message(message["text"], pipeline, send_event):
                    break
    except WebSocketDisconnect:
        # Kept in case a call inside the loop raises it directly.
        client_gone = True
        logger.info("Client disconnected: %s", ws.client)
    finally:
        try:
            await pipeline.stop_async()
        except Exception:
            logger.exception("Error during pipeline stop")

    # Reached only when the loop ended without an unexpected exception. When
    # the client asked to stop, send a proper close frame.
    if not client_gone:
        try:
            await ws.close()
        except (RuntimeError, OSError, WebSocketDisconnect):
            logger.debug("WebSocket already closed: %s", ws.client)
if __name__ == "__main__":
    # Pass the app object rather than the "main:app" import string. With
    # reload disabled, uvicorn accepts the object directly. The import string
    # would import this file a second time under the module name "main",
    # re-running the logging setup and building a second FastAPI app, and it
    # would fail if the file were not named main.py.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False, log_config=LOG_CONFIG)