Spaces:
Running
Running
GitHub Actions committed on
Commit ·
815b978
1
Parent(s): 9563e4a
Deploy a45bfc7
Browse files- app/api/transcribe.py +74 -0
- app/api/tts.py +26 -0
- app/core/config.py +6 -0
- app/main.py +15 -0
- app/models/speech.py +9 -0
- app/services/transcriber.py +63 -0
- app/services/tts_client.py +46 -0
- requirements.txt +1 -0
- tests/test_speech_endpoints.py +61 -0
app/api/transcribe.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Annotated
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status
|
| 4 |
+
|
| 5 |
+
from app.models.speech import TranscribeResponse
|
| 6 |
+
from app.security.jwt_auth import verify_jwt
|
| 7 |
+
|
| 8 |
+
router = APIRouter()
|
| 9 |
+
|
| 10 |
+
_ALLOWED_AUDIO_TYPES: frozenset[str] = frozenset(
|
| 11 |
+
{
|
| 12 |
+
"audio/webm",
|
| 13 |
+
"audio/wav",
|
| 14 |
+
"audio/x-wav",
|
| 15 |
+
"audio/mpeg",
|
| 16 |
+
"audio/mp3",
|
| 17 |
+
"audio/mp4",
|
| 18 |
+
"audio/ogg",
|
| 19 |
+
"audio/flac",
|
| 20 |
+
}
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@router.post("")
async def transcribe_endpoint(
    request: Request,
    audio: Annotated[UploadFile, File(...)],
    _: Annotated[dict, Depends(verify_jwt)],
    language: Annotated[str | None, Form()] = None,
) -> TranscribeResponse:
    """Transcribe an uploaded audio clip via ``app.state.transcriber``.

    Validates the upload (allowed media type, non-empty, size limit) and the
    optional language hint before delegating to the transcriber service.

    Raises HTTPException: 503 when the transcriber is unconfigured, 415 for an
    unsupported media type, 400 for an empty file, 413 when the upload exceeds
    ``TRANSCRIBE_MAX_UPLOAD_BYTES``, and 422 for an invalid language code.
    """
    settings = request.app.state.settings
    transcriber = request.app.state.transcriber

    if not transcriber.is_configured:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Transcription service is not configured.",
        )

    # Browsers frequently send parameterized types such as
    # "audio/webm;codecs=opus" (MediaRecorder). Compare only the bare media
    # type against the allow-list, otherwise valid recordings are rejected.
    content_type = (audio.content_type or "").split(";", 1)[0].strip().lower()
    if content_type not in _ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported audio format.",
        )

    audio_bytes = await audio.read()
    if not audio_bytes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Audio file is empty.",
        )

    if len(audio_bytes) > settings.TRANSCRIBE_MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=status.HTTP_413_CONTENT_TOO_LARGE,
            detail="Audio file exceeds maximum allowed size.",
        )

    # Normalize the optional language hint; a blank string counts as "absent".
    language_code = language.strip().lower() if language and language.strip() else None
    if language_code and len(language_code) > 10:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Invalid language code.",
        )

    transcript = await transcriber.transcribe(
        filename=audio.filename or "audio.webm",
        content_type=content_type,
        audio_bytes=audio_bytes,
        language=language_code,
    )

    return TranscribeResponse(transcript=transcript)
|
app/api/tts.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Annotated
|
| 2 |
+
|
| 3 |
+
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
| 4 |
+
from fastapi.responses import Response
|
| 5 |
+
|
| 6 |
+
from app.models.speech import SynthesizeRequest
|
| 7 |
+
from app.security.jwt_auth import verify_jwt
|
| 8 |
+
|
| 9 |
+
router = APIRouter()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@router.post("")
async def synthesize_endpoint(
    request: Request,
    payload: SynthesizeRequest,
    _: Annotated[dict, Depends(verify_jwt)],
) -> Response:
    """Synthesize speech for ``payload.text`` and return raw WAV bytes.

    Raises HTTPException: 503 when the TTS client is unconfigured, 422 when
    the text is empty after stripping whitespace.
    """
    tts_client = request.app.state.tts_client
    if not tts_client.is_configured:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="TTS service is not configured.",
        )

    # The model's min_length=1 still admits whitespace-only input, which would
    # reach the TTS backend as an empty string after stripping — reject it here.
    text = payload.text.strip()
    if not text:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Text must not be empty.",
        )

    audio_bytes = await tts_client.synthesize(text)
    return Response(content=audio_bytes, media_type="audio/wav")
|
app/core/config.py
CHANGED
|
@@ -12,6 +12,7 @@ class Settings(BaseSettings):
|
|
| 12 |
OLLAMA_MODEL: Optional[str] = None
|
| 13 |
GROQ_MODEL_DEFAULT: str = "llama-3.1-8b-instant"
|
| 14 |
GROQ_MODEL_LARGE: str = "llama-3.3-70b-versatile"
|
|
|
|
| 15 |
|
| 16 |
# Vector
|
| 17 |
QDRANT_URL: str
|
|
@@ -67,6 +68,11 @@ class Settings(BaseSettings):
|
|
| 67 |
# In prod, the API Space calls the HF embedder/reranker Spaces via HTTP.
|
| 68 |
EMBEDDER_URL: str = "http://localhost:7860"
|
| 69 |
RERANKER_URL: str = "http://localhost:7861"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
| 72 |
|
|
|
|
| 12 |
OLLAMA_MODEL: Optional[str] = None
|
| 13 |
GROQ_MODEL_DEFAULT: str = "llama-3.1-8b-instant"
|
| 14 |
GROQ_MODEL_LARGE: str = "llama-3.3-70b-versatile"
|
| 15 |
+
GROQ_TRANSCRIBE_MODEL: str = "whisper-large-v3-turbo"
|
| 16 |
|
| 17 |
# Vector
|
| 18 |
QDRANT_URL: str
|
|
|
|
| 68 |
# In prod, the API Space calls the HF embedder/reranker Spaces via HTTP.
|
| 69 |
EMBEDDER_URL: str = "http://localhost:7860"
|
| 70 |
RERANKER_URL: str = "http://localhost:7861"
|
| 71 |
+
TTS_SPACE_URL: str = "http://localhost:7862"
|
| 72 |
+
|
| 73 |
+
# Speech-to-text upload constraints
|
| 74 |
+
TRANSCRIBE_MAX_UPLOAD_BYTES: int = 2 * 1024 * 1024
|
| 75 |
+
TRANSCRIBE_TIMEOUT_SECONDS: float = 25.0
|
| 76 |
|
| 77 |
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
| 78 |
|
app/main.py
CHANGED
|
@@ -14,6 +14,8 @@ from app.api.admin import router as admin_router
|
|
| 14 |
from app.api.chat import router as chat_router
|
| 15 |
from app.api.feedback import router as feedback_router
|
| 16 |
from app.api.health import router as health_router
|
|
|
|
|
|
|
| 17 |
from app.core.config import get_settings
|
| 18 |
from app.core.exceptions import AppError
|
| 19 |
from app.core.logging import get_logger
|
|
@@ -25,6 +27,8 @@ from app.services.github_log import GithubLog
|
|
| 25 |
from app.services.llm_client import get_llm_client, TpmBucket
|
| 26 |
from app.services.reranker import Reranker
|
| 27 |
from app.services.semantic_cache import SemanticCache
|
|
|
|
|
|
|
| 28 |
from app.services.conversation_store import ConversationStore
|
| 29 |
from qdrant_client import QdrantClient
|
| 30 |
|
|
@@ -156,6 +160,15 @@ async def lifespan(app: FastAPI):
|
|
| 156 |
context_path=settings.GEMINI_CONTEXT_PATH,
|
| 157 |
)
|
| 158 |
app.state.gemini_client = gemini_client
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
|
| 160 |
from app.services.vector_store import VectorStore
|
| 161 |
from app.security.guard_classifier import GuardClassifier
|
|
@@ -286,6 +299,8 @@ def create_app() -> FastAPI:
|
|
| 286 |
|
| 287 |
app.include_router(health_router, tags=["Health"])
|
| 288 |
app.include_router(chat_router, prefix="/chat", tags=["Chat"])
|
|
|
|
|
|
|
| 289 |
app.include_router(feedback_router, prefix="/chat", tags=["Feedback"])
|
| 290 |
app.include_router(admin_router, prefix="/admin", tags=["Admin"])
|
| 291 |
|
|
|
|
| 14 |
from app.api.chat import router as chat_router
|
| 15 |
from app.api.feedback import router as feedback_router
|
| 16 |
from app.api.health import router as health_router
|
| 17 |
+
from app.api.tts import router as tts_router
|
| 18 |
+
from app.api.transcribe import router as transcribe_router
|
| 19 |
from app.core.config import get_settings
|
| 20 |
from app.core.exceptions import AppError
|
| 21 |
from app.core.logging import get_logger
|
|
|
|
| 27 |
from app.services.llm_client import get_llm_client, TpmBucket
|
| 28 |
from app.services.reranker import Reranker
|
| 29 |
from app.services.semantic_cache import SemanticCache
|
| 30 |
+
from app.services.transcriber import GroqTranscriber
|
| 31 |
+
from app.services.tts_client import TTSClient
|
| 32 |
from app.services.conversation_store import ConversationStore
|
| 33 |
from qdrant_client import QdrantClient
|
| 34 |
|
|
|
|
| 160 |
context_path=settings.GEMINI_CONTEXT_PATH,
|
| 161 |
)
|
| 162 |
app.state.gemini_client = gemini_client
|
| 163 |
+
app.state.transcriber = GroqTranscriber(
|
| 164 |
+
api_key=settings.GROQ_API_KEY or "",
|
| 165 |
+
model=settings.GROQ_TRANSCRIBE_MODEL,
|
| 166 |
+
timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
|
| 167 |
+
)
|
| 168 |
+
app.state.tts_client = TTSClient(
|
| 169 |
+
tts_space_url=settings.TTS_SPACE_URL,
|
| 170 |
+
timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
|
| 171 |
+
)
|
| 172 |
|
| 173 |
from app.services.vector_store import VectorStore
|
| 174 |
from app.security.guard_classifier import GuardClassifier
|
|
|
|
| 299 |
|
| 300 |
app.include_router(health_router, tags=["Health"])
|
| 301 |
app.include_router(chat_router, prefix="/chat", tags=["Chat"])
|
| 302 |
+
app.include_router(transcribe_router, prefix="/transcribe", tags=["Transcribe"])
|
| 303 |
+
app.include_router(tts_router, prefix="/tts", tags=["TTS"])
|
| 304 |
app.include_router(feedback_router, prefix="/chat", tags=["Feedback"])
|
| 305 |
app.include_router(admin_router, prefix="/admin", tags=["Admin"])
|
| 306 |
|
app/models/speech.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TranscribeResponse(BaseModel):
    """Response body for the speech-to-text endpoint."""

    # Non-empty transcript, capped so responses stay bounded.
    transcript: str = Field(min_length=1, max_length=5000)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class SynthesizeRequest(BaseModel):
    """Request body for the text-to-speech endpoint."""

    # Short utterance to synthesize; the small cap bounds TTS latency.
    text: str = Field(min_length=1, max_length=300)
|
app/services/transcriber.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
from groq import AsyncGroq
|
| 5 |
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
|
| 6 |
+
|
| 7 |
+
from app.core.exceptions import GenerationError
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GroqTranscriber:
    """Async speech-to-text client backed by the Groq transcription API.

    A missing API key leaves the instance unconfigured instead of failing at
    startup; callers should check ``is_configured`` before transcribing.
    """

    def __init__(
        self,
        api_key: str,
        model: str,
        timeout_seconds: float,
    ) -> None:
        # No key -> no SDK client; is_configured reports False.
        self._client = AsyncGroq(api_key=api_key) if api_key else None
        self._model = model
        self._timeout_seconds = timeout_seconds

    @property
    def is_configured(self) -> bool:
        """True when an API key was supplied and the SDK client exists."""
        return self._client is not None

    # Retry transient transport failures of the raw API call. The retry must
    # wrap the call *before* exceptions are converted to GenerationError in
    # ``transcribe`` — decorating ``transcribe`` itself meant tenacity never
    # saw an httpx error and therefore never retried.
    @retry(
        stop=stop_after_attempt(2),
        wait=wait_fixed(0.8),
        # httpx.TimeoutException is a subclass of httpx.RequestError.
        retry=retry_if_exception_type(httpx.RequestError),
    )
    async def _request_transcript(
        self,
        filename: str,
        content_type: str,
        audio_bytes: bytes,
        language: str | None,
    ) -> str:
        response = await self._client.audio.transcriptions.create(
            file=(filename, audio_bytes, content_type),
            model=self._model,
            temperature=0,
            language=language,
        )
        # The SDK normally returns an object exposing ``.text``; tolerate a
        # plain-dict payload as well.
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text.strip()
        if isinstance(response, dict):
            value = response.get("text")
            if isinstance(value, str) and value.strip():
                return value.strip()
        raise GenerationError("Transcription response did not contain text")

    async def transcribe(
        self,
        filename: str,
        content_type: str,
        audio_bytes: bytes,
        language: str | None = None,
    ) -> str:
        """Return the transcript for ``audio_bytes``.

        Raises GenerationError on misconfiguration, timeout, an empty
        transcription result, or any upstream failure.
        """
        if not self._client:
            raise GenerationError("Transcriber is not configured with GROQ_API_KEY")

        try:
            # The timeout bounds the whole operation, retries included.
            return await asyncio.wait_for(
                self._request_transcript(filename, content_type, audio_bytes, language),
                timeout=self._timeout_seconds,
            )
        except (TimeoutError, asyncio.TimeoutError) as exc:  # alias only on 3.11+
            raise GenerationError("Transcription timed out") from exc
        except GenerationError:
            raise
        except Exception as exc:
            raise GenerationError("Transcription failed", context={"error": str(exc)}) from exc
|
app/services/tts_client.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
|
| 5 |
+
from app.core.exceptions import GenerationError
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TTSClient:
    """Thin HTTP client for the TTS Space's ``/synthesize`` endpoint."""

    def __init__(self, tts_space_url: str, timeout_seconds: float) -> None:
        # Trailing slash is dropped so URL joining below stays predictable.
        self._tts_space_url = tts_space_url.rstrip("/")
        self._timeout_seconds = timeout_seconds

    @property
    def is_configured(self) -> bool:
        """True when a base URL was supplied."""
        return bool(self._tts_space_url)

    async def synthesize(self, text: str) -> bytes:
        """POST ``text`` to the TTS Space and return the audio bytes.

        Raises GenerationError on misconfiguration, timeout, an empty or
        error response, or any transport failure.
        """
        if not self.is_configured:
            raise GenerationError("TTS client is not configured")

        async def _call() -> bytes:
            async with httpx.AsyncClient(timeout=self._timeout_seconds) as client:
                response = await client.post(
                    f"{self._tts_space_url}/synthesize",
                    json={"text": text},
                    headers={"Content-Type": "application/json"},
                )
                response.raise_for_status()
                audio_bytes = response.content
                if not audio_bytes:
                    raise GenerationError("TTS response was empty")
                return audio_bytes

        try:
            # wait_for adds a hard ceiling on top of httpx's own timeout.
            return await asyncio.wait_for(_call(), timeout=self._timeout_seconds)
        except (TimeoutError, asyncio.TimeoutError) as exc:  # alias only on 3.11+
            raise GenerationError("TTS request timed out") from exc
        except httpx.HTTPStatusError as exc:
            raise GenerationError(
                "TTS upstream returned an error",
                context={"status_code": exc.response.status_code},
            ) from exc
        except GenerationError:
            raise
        except httpx.RequestError as exc:
            # Transport-level failures (DNS, connect, read) — previously these
            # fell into the generic bucket and lost their classification.
            raise GenerationError("TTS request failed", context={"error": str(exc)}) from exc
        except Exception as exc:
            raise GenerationError("TTS synthesis failed", context={"error": str(exc)}) from exc
|
requirements.txt
CHANGED
|
@@ -9,6 +9,7 @@
|
|
| 9 |
fastapi>=0.115.0
|
| 10 |
uvicorn[standard]>=0.29.0
|
| 11 |
uvloop>=0.19.0
|
|
|
|
| 12 |
pydantic-settings>=2.2.1
|
| 13 |
langgraph>=0.2.0
|
| 14 |
qdrant-client==1.9.1
|
|
|
|
| 9 |
fastapi>=0.115.0
|
| 10 |
uvicorn[standard]>=0.29.0
|
| 11 |
uvloop>=0.19.0
|
| 12 |
+
python-multipart>=0.0.9
|
| 13 |
pydantic-settings>=2.2.1
|
| 14 |
langgraph>=0.2.0
|
| 15 |
qdrant-client==1.9.1
|
tests/test_speech_endpoints.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_transcribe_requires_auth(app_client):
    """An upload without a bearer token is rejected with 401."""
    upload = {"audio": ("sample.webm", b"abc", "audio/webm")}
    response = app_client.post("/transcribe", files=upload)
    assert response.status_code == 401
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def test_transcribe_success(app_client, valid_token):
    """An authenticated, valid upload returns the stubbed transcript."""

    async def fake_transcribe(filename, content_type, audio_bytes, language=None):
        await asyncio.sleep(0)
        return "hello from voice"

    transcriber = app_client.app.state.transcriber
    original_transcribe = transcriber.transcribe
    transcriber.transcribe = fake_transcribe
    try:
        response = app_client.post(
            "/transcribe",
            files={"audio": ("sample.webm", b"abc", "audio/webm")},
            headers={"Authorization": f"Bearer {valid_token}"},
        )
    finally:
        # Restore shared app state so the stub cannot leak into other tests.
        transcriber.transcribe = original_transcribe

    assert response.status_code == 200
    assert response.json()["transcript"] == "hello from voice"
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def test_transcribe_rejects_oversized_audio(app_client, valid_token):
    """Uploads larger than TRANSCRIBE_MAX_UPLOAD_BYTES are refused with 413."""
    settings = app_client.app.state.settings
    original_limit = settings.TRANSCRIBE_MAX_UPLOAD_BYTES
    settings.TRANSCRIBE_MAX_UPLOAD_BYTES = 2
    try:
        response = app_client.post(
            "/transcribe",
            files={"audio": ("sample.webm", b"abcdef", "audio/webm")},
            headers={"Authorization": f"Bearer {valid_token}"},
        )
    finally:
        # Restore the shared settings object so later tests see the real limit.
        settings.TRANSCRIBE_MAX_UPLOAD_BYTES = original_limit

    assert response.status_code == 413
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def test_tts_requires_auth(app_client):
    """A synthesis request without a bearer token is rejected with 401."""
    body = {"text": "Hello world"}
    response = app_client.post("/tts", json=body)
    assert response.status_code == 401
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def test_tts_success(app_client, valid_token):
    """An authenticated synthesis request streams back the stubbed WAV bytes."""

    async def fake_synthesize(text):
        await asyncio.sleep(0)
        return b"RIFF....fake"

    tts_client = app_client.app.state.tts_client
    original_synthesize = tts_client.synthesize
    tts_client.synthesize = fake_synthesize
    try:
        response = app_client.post(
            "/tts",
            json={"text": "Hello world"},
            headers={"Authorization": f"Bearer {valid_token}"},
        )
    finally:
        # Restore shared app state so the stub cannot leak into other tests.
        tts_client.synthesize = original_synthesize

    assert response.status_code == 200
    assert response.headers.get("content-type", "").startswith("audio/wav")
    assert response.content == b"RIFF....fake"
|