Spaces:
Running
Running
GitHub Actions committed on
Commit ·
4a9ec15
1
Parent(s): e651eb1
Deploy 92e36db
Browse files
- app/api/transcribe.py +5 -4
- app/api/tts.py +10 -6
- app/core/config.py +8 -0
- app/main.py +2 -0
- app/security/rate_limiter.py +6 -0
- app/services/transcriber.py +16 -14
- app/services/tts_client.py +41 -2
- requirements.txt +3 -1
- test_stream.py +28 -0
- tests/test_speech_endpoints.py +8 -7
- tests/test_transcriber_normalization.py +17 -3
app/api/transcribe.py
CHANGED
|
@@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, Uplo
|
|
| 4 |
|
| 5 |
from app.models.speech import TranscribeResponse
|
| 6 |
from app.security.jwt_auth import verify_jwt
|
|
|
|
| 7 |
|
| 8 |
router = APIRouter()
|
| 9 |
|
|
@@ -22,6 +23,7 @@ _ALLOWED_AUDIO_TYPES: frozenset[str] = frozenset(
|
|
| 22 |
|
| 23 |
|
| 24 |
@router.post("")
|
|
|
|
| 25 |
async def transcribe_endpoint(
|
| 26 |
request: Request,
|
| 27 |
audio: Annotated[UploadFile, File(...)],
|
|
@@ -44,14 +46,13 @@ async def transcribe_endpoint(
|
|
| 44 |
detail="Unsupported audio format.",
|
| 45 |
)
|
| 46 |
|
| 47 |
-
|
| 48 |
-
if not audio_bytes:
|
| 49 |
raise HTTPException(
|
| 50 |
status_code=status.HTTP_400_BAD_REQUEST,
|
| 51 |
detail="Audio file is empty.",
|
| 52 |
)
|
| 53 |
|
| 54 |
-
if
|
| 55 |
raise HTTPException(
|
| 56 |
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
| 57 |
detail="Audio file exceeds maximum allowed size.",
|
|
@@ -67,7 +68,7 @@ async def transcribe_endpoint(
|
|
| 67 |
transcript = await transcriber.transcribe(
|
| 68 |
filename=audio.filename or "audio.webm",
|
| 69 |
content_type=content_type,
|
| 70 |
-
|
| 71 |
language=language_code,
|
| 72 |
)
|
| 73 |
|
|
|
|
| 4 |
|
| 5 |
from app.models.speech import TranscribeResponse
|
| 6 |
from app.security.jwt_auth import verify_jwt
|
| 7 |
+
from app.security.rate_limiter import transcribe_rate_limit
|
| 8 |
|
| 9 |
router = APIRouter()
|
| 10 |
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@router.post("")
|
| 26 |
+
@transcribe_rate_limit()
|
| 27 |
async def transcribe_endpoint(
|
| 28 |
request: Request,
|
| 29 |
audio: Annotated[UploadFile, File(...)],
|
|
|
|
| 46 |
detail="Unsupported audio format.",
|
| 47 |
)
|
| 48 |
|
| 49 |
+
if audio.size is None or audio.size == 0:
|
|
|
|
| 50 |
raise HTTPException(
|
| 51 |
status_code=status.HTTP_400_BAD_REQUEST,
|
| 52 |
detail="Audio file is empty.",
|
| 53 |
)
|
| 54 |
|
| 55 |
+
if audio.size > settings.TRANSCRIBE_MAX_UPLOAD_BYTES:
|
| 56 |
raise HTTPException(
|
| 57 |
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
| 58 |
detail="Audio file exceeds maximum allowed size.",
|
|
|
|
| 68 |
transcript = await transcriber.transcribe(
|
| 69 |
filename=audio.filename or "audio.webm",
|
| 70 |
content_type=content_type,
|
| 71 |
+
audio_file=audio.file,
|
| 72 |
language=language_code,
|
| 73 |
)
|
| 74 |
|
app/api/tts.py
CHANGED
|
@@ -1,20 +1,22 @@
|
|
| 1 |
from typing import Annotated
|
| 2 |
|
| 3 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
| 4 |
-
from fastapi.responses import
|
| 5 |
|
| 6 |
from app.models.speech import SynthesizeRequest
|
| 7 |
from app.security.jwt_auth import verify_jwt
|
|
|
|
| 8 |
|
| 9 |
router = APIRouter()
|
| 10 |
|
| 11 |
|
| 12 |
@router.post("")
|
|
|
|
| 13 |
async def synthesize_endpoint(
|
| 14 |
request: Request,
|
| 15 |
payload: SynthesizeRequest,
|
| 16 |
_: Annotated[dict, Depends(verify_jwt)],
|
| 17 |
-
) ->
|
| 18 |
tts_client = request.app.state.tts_client
|
| 19 |
if not tts_client.is_configured:
|
| 20 |
raise HTTPException(
|
|
@@ -22,8 +24,10 @@ async def synthesize_endpoint(
|
|
| 22 |
detail="TTS service is not configured.",
|
| 23 |
)
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
)
|
| 29 |
-
return Response(content=audio_bytes, media_type="audio/wav")
|
|
|
|
| 1 |
from typing import Annotated
|
| 2 |
|
| 3 |
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
| 4 |
+
from fastapi.responses import StreamingResponse
|
| 5 |
|
| 6 |
from app.models.speech import SynthesizeRequest
|
| 7 |
from app.security.jwt_auth import verify_jwt
|
| 8 |
+
from app.security.rate_limiter import tts_rate_limit
|
| 9 |
|
| 10 |
router = APIRouter()
|
| 11 |
|
| 12 |
|
| 13 |
@router.post("")
|
| 14 |
+
@tts_rate_limit()
|
| 15 |
async def synthesize_endpoint(
|
| 16 |
request: Request,
|
| 17 |
payload: SynthesizeRequest,
|
| 18 |
_: Annotated[dict, Depends(verify_jwt)],
|
| 19 |
+
) -> StreamingResponse:
|
| 20 |
tts_client = request.app.state.tts_client
|
| 21 |
if not tts_client.is_configured:
|
| 22 |
raise HTTPException(
|
|
|
|
| 24 |
detail="TTS service is not configured.",
|
| 25 |
)
|
| 26 |
|
| 27 |
+
return StreamingResponse(
|
| 28 |
+
tts_client.synthesize_stream(
|
| 29 |
+
payload.text.strip(),
|
| 30 |
+
voice=payload.voice.strip().lower(),
|
| 31 |
+
),
|
| 32 |
+
media_type="audio/wav"
|
| 33 |
)
|
|
|
app/core/config.py
CHANGED
|
@@ -65,6 +65,14 @@ class Settings(BaseSettings):
|
|
| 65 |
# Speech-to-text upload constraints
|
| 66 |
TRANSCRIBE_MAX_UPLOAD_BYTES: int = 2 * 1024 * 1024
|
| 67 |
TRANSCRIBE_TIMEOUT_SECONDS: float = 25.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
| 70 |
|
|
|
|
| 65 |
# Speech-to-text upload constraints
|
| 66 |
TRANSCRIBE_MAX_UPLOAD_BYTES: int = 2 * 1024 * 1024
|
| 67 |
TRANSCRIBE_TIMEOUT_SECONDS: float = 25.0
|
| 68 |
+
TRANSCRIBE_DEFAULT_LANGUAGE: str = "en"
|
| 69 |
+
TRANSCRIBE_REPLACEMENTS: dict[str, str] = {
|
| 70 |
+
r"\bwalk experience\b": "work experience",
|
| 71 |
+
r"\btext stack\b": "tech stack",
|
| 72 |
+
r"\bprofessional sitting\b": "professional setting",
|
| 73 |
+
r"\btech stocks\b": "tech stack",
|
| 74 |
+
r"\bwhat tech stack does he\s+used\b": "what tech stack does he use",
|
| 75 |
+
}
|
| 76 |
|
| 77 |
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
|
| 78 |
|
app/main.py
CHANGED
|
@@ -144,6 +144,8 @@ async def lifespan(app: FastAPI):
|
|
| 144 |
api_key=settings.GROQ_API_KEY or "",
|
| 145 |
model=settings.GROQ_TRANSCRIBE_MODEL,
|
| 146 |
timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
|
|
|
|
|
|
|
| 147 |
)
|
| 148 |
app.state.tts_client = TTSClient(
|
| 149 |
tts_space_url=settings.TTS_SPACE_URL,
|
|
|
|
| 144 |
api_key=settings.GROQ_API_KEY or "",
|
| 145 |
model=settings.GROQ_TRANSCRIBE_MODEL,
|
| 146 |
timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
|
| 147 |
+
default_language=settings.TRANSCRIBE_DEFAULT_LANGUAGE,
|
| 148 |
+
replacements=settings.TRANSCRIBE_REPLACEMENTS,
|
| 149 |
)
|
| 150 |
app.state.tts_client = TTSClient(
|
| 151 |
tts_space_url=settings.TTS_SPACE_URL,
|
app/security/rate_limiter.py
CHANGED
|
@@ -18,3 +18,9 @@ async def custom_rate_limit_handler(request: Request, exc: Exception) -> JSONRes
|
|
| 18 |
# Decorator factory chat_rate_limit that applies 20/minute limit.
|
| 19 |
def chat_rate_limit() -> Callable:
|
| 20 |
return limiter.limit("20/minute")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
# Decorator factory chat_rate_limit that applies 20/minute limit.
|
| 19 |
def chat_rate_limit() -> Callable:
|
| 20 |
return limiter.limit("20/minute")
|
| 21 |
+
|
| 22 |
+
def tts_rate_limit() -> Callable:
|
| 23 |
+
return limiter.limit("10/minute")
|
| 24 |
+
|
| 25 |
+
def transcribe_rate_limit() -> Callable:
|
| 26 |
+
return limiter.limit("20/minute")
|
app/services/transcriber.py
CHANGED
|
@@ -9,19 +9,12 @@ from app.core.exceptions import GenerationError
|
|
| 9 |
|
| 10 |
_FILLER_PREFIX_RE = re.compile(r"^\s*(uh+|um+|erm+|like|you know|please|hey)\s+", re.IGNORECASE)
|
| 11 |
_MULTISPACE_RE = re.compile(r"\s+")
|
| 12 |
-
_TRANSCRIPT_REPLACEMENTS: tuple[tuple[re.Pattern[str], str], ...] = (
|
| 13 |
-
(re.compile(r"\bwalk experience\b", re.IGNORECASE), "work experience"),
|
| 14 |
-
(re.compile(r"\btext stack\b", re.IGNORECASE), "tech stack"),
|
| 15 |
-
(re.compile(r"\bprofessional sitting\b", re.IGNORECASE), "professional setting"),
|
| 16 |
-
(re.compile(r"\btech stocks\b", re.IGNORECASE), "tech stack"),
|
| 17 |
-
(re.compile(r"\bwhat tech stack does he\s+used\b", re.IGNORECASE), "what tech stack does he use"),
|
| 18 |
-
)
|
| 19 |
|
| 20 |
|
| 21 |
-
def _normalise_transcript_text(text: str) -> str:
|
| 22 |
cleaned = text.strip()
|
| 23 |
cleaned = _FILLER_PREFIX_RE.sub("", cleaned)
|
| 24 |
-
for pattern, replacement in
|
| 25 |
cleaned = pattern.sub(replacement, cleaned)
|
| 26 |
cleaned = _MULTISPACE_RE.sub(" ", cleaned)
|
| 27 |
return cleaned.strip()
|
|
@@ -33,10 +26,17 @@ class GroqTranscriber:
|
|
| 33 |
api_key: str,
|
| 34 |
model: str,
|
| 35 |
timeout_seconds: float,
|
|
|
|
|
|
|
| 36 |
) -> None:
|
| 37 |
self._client = AsyncGroq(api_key=api_key) if api_key else None
|
| 38 |
self._model = model
|
| 39 |
self._timeout_seconds = timeout_seconds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
@property
|
| 42 |
def is_configured(self) -> bool:
|
|
@@ -51,26 +51,28 @@ class GroqTranscriber:
|
|
| 51 |
self,
|
| 52 |
filename: str,
|
| 53 |
content_type: str,
|
| 54 |
-
|
| 55 |
language: str | None = None,
|
| 56 |
) -> str:
|
| 57 |
if not self._client:
|
| 58 |
raise GenerationError("Transcriber is not configured with GROQ_API_KEY")
|
| 59 |
|
|
|
|
|
|
|
| 60 |
async def _call() -> str:
|
| 61 |
response = await self._client.audio.transcriptions.create(
|
| 62 |
-
file=(filename,
|
| 63 |
model=self._model,
|
| 64 |
temperature=0,
|
| 65 |
-
language=
|
| 66 |
)
|
| 67 |
text = getattr(response, "text", None)
|
| 68 |
if isinstance(text, str) and text.strip():
|
| 69 |
-
return _normalise_transcript_text(text)
|
| 70 |
if isinstance(response, dict):
|
| 71 |
value = response.get("text")
|
| 72 |
if isinstance(value, str) and value.strip():
|
| 73 |
-
return _normalise_transcript_text(value)
|
| 74 |
raise GenerationError("Transcription response did not contain text")
|
| 75 |
|
| 76 |
try:
|
|
|
|
| 9 |
|
| 10 |
_FILLER_PREFIX_RE = re.compile(r"^\s*(uh+|um+|erm+|like|you know|please|hey)\s+", re.IGNORECASE)
|
| 11 |
_MULTISPACE_RE = re.compile(r"\s+")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
+
def _normalise_transcript_text(text: str, replacements: tuple[tuple[re.Pattern[str], str], ...]) -> str:
|
| 15 |
cleaned = text.strip()
|
| 16 |
cleaned = _FILLER_PREFIX_RE.sub("", cleaned)
|
| 17 |
+
for pattern, replacement in replacements:
|
| 18 |
cleaned = pattern.sub(replacement, cleaned)
|
| 19 |
cleaned = _MULTISPACE_RE.sub(" ", cleaned)
|
| 20 |
return cleaned.strip()
|
|
|
|
| 26 |
api_key: str,
|
| 27 |
model: str,
|
| 28 |
timeout_seconds: float,
|
| 29 |
+
default_language: str = "en",
|
| 30 |
+
replacements: dict[str, str] | None = None,
|
| 31 |
) -> None:
|
| 32 |
self._client = AsyncGroq(api_key=api_key) if api_key else None
|
| 33 |
self._model = model
|
| 34 |
self._timeout_seconds = timeout_seconds
|
| 35 |
+
self._default_language = default_language
|
| 36 |
+
self._replacements = tuple(
|
| 37 |
+
(re.compile(pattern, re.IGNORECASE), replacement)
|
| 38 |
+
for pattern, replacement in (replacements or {}).items()
|
| 39 |
+
)
|
| 40 |
|
| 41 |
@property
|
| 42 |
def is_configured(self) -> bool:
|
|
|
|
| 51 |
self,
|
| 52 |
filename: str,
|
| 53 |
content_type: str,
|
| 54 |
+
audio_file,
|
| 55 |
language: str | None = None,
|
| 56 |
) -> str:
|
| 57 |
if not self._client:
|
| 58 |
raise GenerationError("Transcriber is not configured with GROQ_API_KEY")
|
| 59 |
|
| 60 |
+
target_language = language if language else self._default_language
|
| 61 |
+
|
| 62 |
async def _call() -> str:
|
| 63 |
response = await self._client.audio.transcriptions.create(
|
| 64 |
+
file=(filename, audio_file, content_type),
|
| 65 |
model=self._model,
|
| 66 |
temperature=0,
|
| 67 |
+
language=target_language,
|
| 68 |
)
|
| 69 |
text = getattr(response, "text", None)
|
| 70 |
if isinstance(text, str) and text.strip():
|
| 71 |
+
return _normalise_transcript_text(text, self._replacements)
|
| 72 |
if isinstance(response, dict):
|
| 73 |
value = response.get("text")
|
| 74 |
if isinstance(value, str) and value.strip():
|
| 75 |
+
return _normalise_transcript_text(value, self._replacements)
|
| 76 |
raise GenerationError("Transcription response did not contain text")
|
| 77 |
|
| 78 |
try:
|
app/services/tts_client.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
import asyncio
|
| 2 |
-
|
| 3 |
import httpx
|
| 4 |
-
|
| 5 |
from app.core.exceptions import GenerationError
|
| 6 |
|
| 7 |
|
|
@@ -44,3 +42,44 @@ class TTSClient:
|
|
| 44 |
raise
|
| 45 |
except Exception as exc:
|
| 46 |
raise GenerationError("TTS synthesis failed", context={"error": str(exc)}) from exc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
import httpx
|
|
|
|
| 3 |
from app.core.exceptions import GenerationError
|
| 4 |
|
| 5 |
|
|
|
|
| 42 |
raise
|
| 43 |
except Exception as exc:
|
| 44 |
raise GenerationError("TTS synthesis failed", context={"error": str(exc)}) from exc
|
| 45 |
+
|
| 46 |
+
async def synthesize_stream(self, text: str, voice: str = "am_adam"):
|
| 47 |
+
text = text.strip()
|
| 48 |
+
if not text:
|
| 49 |
+
raise GenerationError("TTS request text is empty")
|
| 50 |
+
|
| 51 |
+
loop = asyncio.get_running_loop()
|
| 52 |
+
queue = asyncio.Queue()
|
| 53 |
+
|
| 54 |
+
def _worker():
|
| 55 |
+
try:
|
| 56 |
+
generator = self._pipeline(text, voice=voice, speed=1, split_pattern=r'\n+')
|
| 57 |
+
for gs, ps, audio in generator:
|
| 58 |
+
if audio is not None:
|
| 59 |
+
import numpy as np
|
| 60 |
+
pcm_audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
|
| 61 |
+
loop.call_soon_threadsafe(queue.put_nowait, pcm_audio)
|
| 62 |
+
except Exception as e:
|
| 63 |
+
loop.call_soon_threadsafe(queue.put_nowait, e)
|
| 64 |
+
finally:
|
| 65 |
+
loop.call_soon_threadsafe(queue.put_nowait, None)
|
| 66 |
+
|
| 67 |
+
import threading
|
| 68 |
+
thread = threading.Thread(target=_worker)
|
| 69 |
+
thread.start()
|
| 70 |
+
|
| 71 |
+
import struct
|
| 72 |
+
# 44-byte WAV header with 0xFFFFFFFF for sizes (streaming)
|
| 73 |
+
yield struct.pack('<4sI4s4sIHHIIHH4sI',
|
| 74 |
+
b'RIFF', 0xFFFFFFFF, b'WAVE',
|
| 75 |
+
b'fmt ', 16, 1, 1, 24000, 48000, 2, 16,
|
| 76 |
+
b'data', 0xFFFFFFFF
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
while True:
|
| 80 |
+
chunk = await queue.get()
|
| 81 |
+
if chunk is None:
|
| 82 |
+
break
|
| 83 |
+
if isinstance(chunk, Exception):
|
| 84 |
+
raise GenerationError("TTS synthesis stream failed", context={"error": str(chunk)}) from chunk
|
| 85 |
+
yield chunk
|
requirements.txt
CHANGED
|
@@ -26,4 +26,6 @@ google-genai>=1.0.0
|
|
| 26 |
# fastembed: powers BM25 sparse retrieval (Stage 2). Qdrant/bm25 vocabulary
|
| 27 |
# downloads ~5 MB on first use then runs fully local — no GPU, no network at query time.
|
| 28 |
fastembed>=0.3.6
|
| 29 |
-
toon_format @ git+https://github.com/toon-format/toon-python.git
|
|
|
|
|
|
|
|
|
| 26 |
# fastembed: powers BM25 sparse retrieval (Stage 2). Qdrant/bm25 vocabulary
|
| 27 |
# downloads ~5 MB on first use then runs fully local — no GPU, no network at query time.
|
| 28 |
fastembed>=0.3.6
|
| 29 |
+
toon_format @ git+https://github.com/toon-format/toon-python.git
|
| 30 |
+
kokoro>=0.9.0
|
| 31 |
+
soundfile>=0.13.0
|
test_stream.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from kokoro import KPipeline
|
| 3 |
+
import numpy as np
|
| 4 |
+
import struct
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
async def main():
|
| 8 |
+
pipeline = KPipeline(lang_code='a')
|
| 9 |
+
generator = pipeline("Hello, this is a test of streaming audio from Kokoro.", voice="am_adam", speed=1, split_pattern=r'\n+')
|
| 10 |
+
|
| 11 |
+
with open("test_stream.wav", "wb") as f:
|
| 12 |
+
# Write WAV header
|
| 13 |
+
# chunk_size = 36 + data_size
|
| 14 |
+
header = struct.pack('<4sI4s4sIHHIIHH4sI',
|
| 15 |
+
b'RIFF', 0xFFFFFFFF, b'WAVE',
|
| 16 |
+
b'fmt ', 16, 1, 1, 24000, 48000, 2, 16,
|
| 17 |
+
b'data', 0xFFFFFFFF
|
| 18 |
+
)
|
| 19 |
+
f.write(header)
|
| 20 |
+
|
| 21 |
+
for gs, ps, audio in generator:
|
| 22 |
+
if audio is not None:
|
| 23 |
+
print("Got chunk:", len(audio))
|
| 24 |
+
pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
|
| 25 |
+
f.write(pcm)
|
| 26 |
+
|
| 27 |
+
if __name__ == "__main__":
|
| 28 |
+
asyncio.run(main())
|
tests/test_speech_endpoints.py
CHANGED
|
@@ -10,7 +10,7 @@ def test_transcribe_requires_auth(app_client):
|
|
| 10 |
|
| 11 |
|
| 12 |
def test_transcribe_success(app_client, valid_token):
|
| 13 |
-
async def fake_transcribe(filename, content_type,
|
| 14 |
await asyncio.sleep(0)
|
| 15 |
return "hello from voice"
|
| 16 |
|
|
@@ -46,13 +46,13 @@ def test_tts_requires_auth(app_client):
|
|
| 46 |
def test_tts_success(app_client, valid_token):
|
| 47 |
captured: dict[str, str] = {}
|
| 48 |
|
| 49 |
-
async def
|
| 50 |
await asyncio.sleep(0)
|
| 51 |
captured["text"] = text
|
| 52 |
captured["voice"] = voice
|
| 53 |
-
|
| 54 |
|
| 55 |
-
app_client.app.state.tts_client.
|
| 56 |
|
| 57 |
response = app_client.post(
|
| 58 |
"/tts",
|
|
@@ -62,6 +62,7 @@ def test_tts_success(app_client, valid_token):
|
|
| 62 |
|
| 63 |
assert response.status_code == 200
|
| 64 |
assert response.headers.get("content-type", "").startswith("audio/wav")
|
|
|
|
| 65 |
assert response.content == b"RIFF....fake"
|
| 66 |
assert captured["text"] == "Hello world"
|
| 67 |
assert captured["voice"] == "am_adam"
|
|
@@ -70,12 +71,12 @@ def test_tts_success(app_client, valid_token):
|
|
| 70 |
def test_tts_uses_provided_voice(app_client, valid_token):
|
| 71 |
captured: dict[str, str] = {}
|
| 72 |
|
| 73 |
-
async def
|
| 74 |
await asyncio.sleep(0)
|
| 75 |
captured["voice"] = voice
|
| 76 |
-
|
| 77 |
|
| 78 |
-
app_client.app.state.tts_client.
|
| 79 |
|
| 80 |
response = app_client.post(
|
| 81 |
"/tts",
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def test_transcribe_success(app_client, valid_token):
|
| 13 |
+
async def fake_transcribe(filename, content_type, audio_file, language=None):
|
| 14 |
await asyncio.sleep(0)
|
| 15 |
return "hello from voice"
|
| 16 |
|
|
|
|
| 46 |
def test_tts_success(app_client, valid_token):
|
| 47 |
captured: dict[str, str] = {}
|
| 48 |
|
| 49 |
+
async def fake_synthesize_stream(text, voice="am_adam"):
|
| 50 |
await asyncio.sleep(0)
|
| 51 |
captured["text"] = text
|
| 52 |
captured["voice"] = voice
|
| 53 |
+
yield b"RIFF....fake"
|
| 54 |
|
| 55 |
+
app_client.app.state.tts_client.synthesize_stream = fake_synthesize_stream
|
| 56 |
|
| 57 |
response = app_client.post(
|
| 58 |
"/tts",
|
|
|
|
| 62 |
|
| 63 |
assert response.status_code == 200
|
| 64 |
assert response.headers.get("content-type", "").startswith("audio/wav")
|
| 65 |
+
# StreamingResponse returns chunks, so response.content concatenates them
|
| 66 |
assert response.content == b"RIFF....fake"
|
| 67 |
assert captured["text"] == "Hello world"
|
| 68 |
assert captured["voice"] == "am_adam"
|
|
|
|
| 71 |
def test_tts_uses_provided_voice(app_client, valid_token):
|
| 72 |
captured: dict[str, str] = {}
|
| 73 |
|
| 74 |
+
async def fake_synthesize_stream(text, voice="am_adam"):
|
| 75 |
await asyncio.sleep(0)
|
| 76 |
captured["voice"] = voice
|
| 77 |
+
yield b"RIFF....fake"
|
| 78 |
|
| 79 |
+
app_client.app.state.tts_client.synthesize_stream = fake_synthesize_stream
|
| 80 |
|
| 81 |
response = app_client.post(
|
| 82 |
"/tts",
|
tests/test_transcriber_normalization.py
CHANGED
|
@@ -1,15 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from app.services.transcriber import _normalise_transcript_text
|
| 2 |
|
| 3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
def test_normalise_walk_experience_to_work_experience() -> None:
|
| 5 |
query = "uh what is his walk experience in a professional setting"
|
| 6 |
-
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def test_normalise_text_stack_to_tech_stack() -> None:
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
def test_keeps_clean_transcript_unchanged() -> None:
|
| 14 |
original = "What technologies and skills does he work with?"
|
| 15 |
-
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from app.core.config import get_settings
|
| 4 |
from app.services.transcriber import _normalise_transcript_text
|
| 5 |
|
| 6 |
|
| 7 |
+
def _get_test_replacements():
|
| 8 |
+
replacements = get_settings().TRANSCRIBE_REPLACEMENTS
|
| 9 |
+
return tuple(
|
| 10 |
+
(re.compile(pattern, re.IGNORECASE), replacement)
|
| 11 |
+
for pattern, replacement in replacements.items()
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
def test_normalise_walk_experience_to_work_experience() -> None:
|
| 16 |
query = "uh what is his walk experience in a professional setting"
|
| 17 |
+
replacements = _get_test_replacements()
|
| 18 |
+
assert _normalise_transcript_text(query, replacements) == "what is his work experience in a professional setting"
|
| 19 |
|
| 20 |
|
| 21 |
def test_normalise_text_stack_to_tech_stack() -> None:
|
| 22 |
+
replacements = _get_test_replacements()
|
| 23 |
+
assert _normalise_transcript_text("what text stack does he use", replacements) == "what tech stack does he use"
|
| 24 |
|
| 25 |
|
| 26 |
def test_keeps_clean_transcript_unchanged() -> None:
|
| 27 |
original = "What technologies and skills does he work with?"
|
| 28 |
+
replacements = _get_test_replacements()
|
| 29 |
+
assert _normalise_transcript_text(original, replacements) == original
|