"""TTS (text-to-speech) in the Maris runtime layer."""
|
|
from __future__ import annotations

import base64
import functools
import io
import logging

from fastapi import APIRouter, HTTPException, Response
from pydantic import BaseModel

from maris_core.memory_context import memory_store
from maris_core.utils.env import get_hf_model
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which this module's TTS endpoints are registered.
router = APIRouter()
|
|
|
|
class TtsRequest(BaseModel):
    """Request body accepted by the ``/tts`` endpoint."""

    # Text to be spoken; the /tts handler passes it to the TTS pipeline as-is.
    text: str
    # Requested voice name. Not consulted by the /tts handler in this module.
    voice: str = "maris"
    # Requested language tag. Not consulted by the /tts handler in this module.
    language: str = "lv"
    # Optional session id; when it is non-blank the /tts handler records the
    # spoken text in the memory store under this id.
    session_id: str | None = None
    # Optional persona id. Not consulted by the /tts handler in this module.
    persona_id: str | None = None
|
|
|
|
class TtsResponse(BaseModel):
    """Response body returned by the ``/tts`` endpoint."""

    # Where the audio can be fetched from; the /tts handler fills this with a
    # ``data:audio/wav;base64,...`` URL carrying the audio inline.
    audio_url: str
    # Clip length in seconds; the /tts handler rounds it to two decimals.
    duration_seconds: float
|
|
|
|
class TtsBytesRequest(BaseModel):
    """Request body accepted by the ``/tts_bytes`` endpoint."""

    # Text to be spoken.
    text: str
    # Requested language tag; forwarded to the /tts request unchanged.
    language: str = "lv"
    # Requested voice name; forwarded to the /tts request unchanged.
    voice: str = "maris"
|
|
|
|
@functools.lru_cache(maxsize=1)
def _load_tts_pipeline(model_id: str):
    """Build the Hugging Face text-to-speech pipeline for *model_id*.

    The most recently built pipeline is memoized so the model is not
    re-instantiated on every request. ``lru_cache`` stores successful results
    only, so a failed build is retried on the next call.
    """
    # Imported lazily (as the handler always did) so a missing dependency is
    # raised inside the handler's ``try`` and reported as a 503.
    from transformers import pipeline as hf_pipeline

    return hf_pipeline("text-to-speech", model_id, device=-1)


@router.post("/tts", response_model=TtsResponse)
async def synthesize(req: TtsRequest) -> TtsResponse:
    """Convert text to speech and return it as an inline WAV data URL.

    Args:
        req: Synthesis request. Only ``text`` and ``session_id`` are read
            here; ``voice``, ``language`` and ``persona_id`` are not consulted.

    Returns:
        A ``TtsResponse`` whose ``audio_url`` is a ``data:audio/wav;base64,``
        URL and whose ``duration_seconds`` is rounded to two decimals.

    Raises:
        HTTPException: 503 for any failure during model lookup, synthesis,
            WAV encoding, or the memory-store write.
    """
    # NOTE(review): model loading and inference are synchronous calls inside
    # this coroutine, so the event loop is blocked while they run.
    try:
        model_id = get_hf_model("TTS_MODEL")
        tts = _load_tts_pipeline(model_id)
        output = tts(req.text)

        # Import the submodule explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.io.wavfile`` has been loaded.
        from scipy.io import wavfile

        sampling_rate = output["sampling_rate"]
        # Squeeze once and reuse for both the WAV payload and the duration.
        samples = output["audio"].squeeze()

        buf = io.BytesIO()
        wavfile.write(buf, rate=sampling_rate, data=samples)
        b64 = base64.b64encode(buf.getvalue()).decode()
        duration = len(samples) / sampling_rate

        session_id = (req.session_id or "").strip()
        if session_id:
            memory_store.remember_message(
                session_id, "assistant", req.text, source="voice_tts"
            )

        return TtsResponse(
            audio_url=f"data:audio/wav;base64,{b64}",
            duration_seconds=round(duration, 2),
        )
    except Exception as exc:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("TTS kļūda: %s", exc)
        raise HTTPException(
            status_code=503,
            detail="Maris AI TTS nav pieejams bez konfigurēta TTS_MODEL.",
        ) from exc
|
|
|
|
async def synthesize_bytes(req: TtsBytesRequest) -> bytes:
    """Synthesize ``req.text`` and return the raw audio bytes.

    Delegates to ``synthesize`` and base64-decodes the payload of the data
    URL it returns.

    Args:
        req: Text plus voice/language, forwarded unchanged to ``synthesize``.

    Returns:
        The decoded audio bytes, or ``b""`` if the returned URL is not a
        ``data:`` URL.

    Raises:
        HTTPException: Propagated from ``synthesize`` when synthesis fails.
    """
    tts_request = TtsRequest(
        text=req.text,
        voice=req.voice,
        language=req.language,
    )
    resp = await synthesize(tts_request)
    if not resp.audio_url.startswith("data:"):
        return b""
    _, b64_data = resp.audio_url.split(",", 1)
    return base64.b64decode(b64_data)


# The HTTP route is served by a thin wrapper that returns a ``Response``
# object: a handler that returns bare ``bytes`` has them passed through
# FastAPI's JSON serialization, which tries to decode them as UTF-8 and fails
# on binary audio. ``name`` keeps the route name the endpoint had before.
@router.post(
    "/tts_bytes",
    name="synthesize_bytes",
    response_class=Response,
    response_model=None,
    responses={200: {"content": {"audio/wav": {}}}},
)
async def _synthesize_bytes_endpoint(req: TtsBytesRequest) -> Response:
    """Return raw audio bytes."""
    audio = await synthesize_bytes(req)
    return Response(content=audio, media_type="audio/wav")
|
|