GitHub Actions committed on
Commit
4a9ec15
·
1 Parent(s): e651eb1

Deploy 92e36db

Browse files
app/api/transcribe.py CHANGED
@@ -4,6 +4,7 @@ from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, Uplo
4
 
5
  from app.models.speech import TranscribeResponse
6
  from app.security.jwt_auth import verify_jwt
 
7
 
8
  router = APIRouter()
9
 
@@ -22,6 +23,7 @@ _ALLOWED_AUDIO_TYPES: frozenset[str] = frozenset(
22
 
23
 
24
  @router.post("")
 
25
  async def transcribe_endpoint(
26
  request: Request,
27
  audio: Annotated[UploadFile, File(...)],
@@ -44,14 +46,13 @@ async def transcribe_endpoint(
44
  detail="Unsupported audio format.",
45
  )
46
 
47
- audio_bytes = await audio.read()
48
- if not audio_bytes:
49
  raise HTTPException(
50
  status_code=status.HTTP_400_BAD_REQUEST,
51
  detail="Audio file is empty.",
52
  )
53
 
54
- if len(audio_bytes) > settings.TRANSCRIBE_MAX_UPLOAD_BYTES:
55
  raise HTTPException(
56
  status_code=status.HTTP_413_CONTENT_TOO_LARGE,
57
  detail="Audio file exceeds maximum allowed size.",
@@ -67,7 +68,7 @@ async def transcribe_endpoint(
67
  transcript = await transcriber.transcribe(
68
  filename=audio.filename or "audio.webm",
69
  content_type=content_type,
70
- audio_bytes=audio_bytes,
71
  language=language_code,
72
  )
73
 
 
4
 
5
  from app.models.speech import TranscribeResponse
6
  from app.security.jwt_auth import verify_jwt
7
+ from app.security.rate_limiter import transcribe_rate_limit
8
 
9
  router = APIRouter()
10
 
 
23
 
24
 
25
  @router.post("")
26
+ @transcribe_rate_limit()
27
  async def transcribe_endpoint(
28
  request: Request,
29
  audio: Annotated[UploadFile, File(...)],
 
46
  detail="Unsupported audio format.",
47
  )
48
 
49
+ if audio.size is None or audio.size == 0:
 
50
  raise HTTPException(
51
  status_code=status.HTTP_400_BAD_REQUEST,
52
  detail="Audio file is empty.",
53
  )
54
 
55
+ if audio.size > settings.TRANSCRIBE_MAX_UPLOAD_BYTES:
56
  raise HTTPException(
57
  status_code=status.HTTP_413_CONTENT_TOO_LARGE,
58
  detail="Audio file exceeds maximum allowed size.",
 
68
  transcript = await transcriber.transcribe(
69
  filename=audio.filename or "audio.webm",
70
  content_type=content_type,
71
+ audio_file=audio.file,
72
  language=language_code,
73
  )
74
 
app/api/tts.py CHANGED
@@ -1,20 +1,22 @@
1
  from typing import Annotated
2
 
3
  from fastapi import APIRouter, Depends, HTTPException, Request, status
4
- from fastapi.responses import Response
5
 
6
  from app.models.speech import SynthesizeRequest
7
  from app.security.jwt_auth import verify_jwt
 
8
 
9
  router = APIRouter()
10
 
11
 
12
  @router.post("")
 
13
  async def synthesize_endpoint(
14
  request: Request,
15
  payload: SynthesizeRequest,
16
  _: Annotated[dict, Depends(verify_jwt)],
17
- ) -> Response:
18
  tts_client = request.app.state.tts_client
19
  if not tts_client.is_configured:
20
  raise HTTPException(
@@ -22,8 +24,10 @@ async def synthesize_endpoint(
22
  detail="TTS service is not configured.",
23
  )
24
 
25
- audio_bytes = await tts_client.synthesize(
26
- payload.text.strip(),
27
- voice=payload.voice.strip().lower(),
 
 
 
28
  )
29
- return Response(content=audio_bytes, media_type="audio/wav")
 
1
  from typing import Annotated
2
 
3
  from fastapi import APIRouter, Depends, HTTPException, Request, status
4
+ from fastapi.responses import StreamingResponse
5
 
6
  from app.models.speech import SynthesizeRequest
7
  from app.security.jwt_auth import verify_jwt
8
+ from app.security.rate_limiter import tts_rate_limit
9
 
10
  router = APIRouter()
11
 
12
 
13
  @router.post("")
14
+ @tts_rate_limit()
15
  async def synthesize_endpoint(
16
  request: Request,
17
  payload: SynthesizeRequest,
18
  _: Annotated[dict, Depends(verify_jwt)],
19
+ ) -> StreamingResponse:
20
  tts_client = request.app.state.tts_client
21
  if not tts_client.is_configured:
22
  raise HTTPException(
 
24
  detail="TTS service is not configured.",
25
  )
26
 
27
+ return StreamingResponse(
28
+ tts_client.synthesize_stream(
29
+ payload.text.strip(),
30
+ voice=payload.voice.strip().lower(),
31
+ ),
32
+ media_type="audio/wav"
33
  )
 
app/core/config.py CHANGED
@@ -65,6 +65,14 @@ class Settings(BaseSettings):
65
  # Speech-to-text upload constraints
66
  TRANSCRIBE_MAX_UPLOAD_BYTES: int = 2 * 1024 * 1024
67
  TRANSCRIBE_TIMEOUT_SECONDS: float = 25.0
 
 
 
 
 
 
 
 
68
 
69
  model_config = SettingsConfigDict(env_file=".env", extra="ignore")
70
 
 
65
  # Speech-to-text upload constraints
66
  TRANSCRIBE_MAX_UPLOAD_BYTES: int = 2 * 1024 * 1024
67
  TRANSCRIBE_TIMEOUT_SECONDS: float = 25.0
68
+ TRANSCRIBE_DEFAULT_LANGUAGE: str = "en"
69
+ TRANSCRIBE_REPLACEMENTS: dict[str, str] = {
70
+ r"\bwalk experience\b": "work experience",
71
+ r"\btext stack\b": "tech stack",
72
+ r"\bprofessional sitting\b": "professional setting",
73
+ r"\btech stocks\b": "tech stack",
74
+ r"\bwhat tech stack does he\s+used\b": "what tech stack does he use",
75
+ }
76
 
77
  model_config = SettingsConfigDict(env_file=".env", extra="ignore")
78
 
app/main.py CHANGED
@@ -144,6 +144,8 @@ async def lifespan(app: FastAPI):
144
  api_key=settings.GROQ_API_KEY or "",
145
  model=settings.GROQ_TRANSCRIBE_MODEL,
146
  timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
 
 
147
  )
148
  app.state.tts_client = TTSClient(
149
  tts_space_url=settings.TTS_SPACE_URL,
 
144
  api_key=settings.GROQ_API_KEY or "",
145
  model=settings.GROQ_TRANSCRIBE_MODEL,
146
  timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
147
+ default_language=settings.TRANSCRIBE_DEFAULT_LANGUAGE,
148
+ replacements=settings.TRANSCRIBE_REPLACEMENTS,
149
  )
150
  app.state.tts_client = TTSClient(
151
  tts_space_url=settings.TTS_SPACE_URL,
app/security/rate_limiter.py CHANGED
@@ -18,3 +18,9 @@ async def custom_rate_limit_handler(request: Request, exc: Exception) -> JSONRes
18
  # Decorator factory chat_rate_limit that applies 20/minute limit.
19
  def chat_rate_limit() -> Callable:
20
  return limiter.limit("20/minute")
 
 
 
 
 
 
 
18
  # Decorator factory chat_rate_limit that applies 20/minute limit.
19
  def chat_rate_limit() -> Callable:
20
  return limiter.limit("20/minute")
21
+
22
+ def tts_rate_limit() -> Callable:
23
+ return limiter.limit("10/minute")
24
+
25
+ def transcribe_rate_limit() -> Callable:
26
+ return limiter.limit("20/minute")
app/services/transcriber.py CHANGED
@@ -9,19 +9,12 @@ from app.core.exceptions import GenerationError
9
 
10
  _FILLER_PREFIX_RE = re.compile(r"^\s*(uh+|um+|erm+|like|you know|please|hey)\s+", re.IGNORECASE)
11
  _MULTISPACE_RE = re.compile(r"\s+")
12
- _TRANSCRIPT_REPLACEMENTS: tuple[tuple[re.Pattern[str], str], ...] = (
13
- (re.compile(r"\bwalk experience\b", re.IGNORECASE), "work experience"),
14
- (re.compile(r"\btext stack\b", re.IGNORECASE), "tech stack"),
15
- (re.compile(r"\bprofessional sitting\b", re.IGNORECASE), "professional setting"),
16
- (re.compile(r"\btech stocks\b", re.IGNORECASE), "tech stack"),
17
- (re.compile(r"\bwhat tech stack does he\s+used\b", re.IGNORECASE), "what tech stack does he use"),
18
- )
19
 
20
 
21
- def _normalise_transcript_text(text: str) -> str:
22
  cleaned = text.strip()
23
  cleaned = _FILLER_PREFIX_RE.sub("", cleaned)
24
- for pattern, replacement in _TRANSCRIPT_REPLACEMENTS:
25
  cleaned = pattern.sub(replacement, cleaned)
26
  cleaned = _MULTISPACE_RE.sub(" ", cleaned)
27
  return cleaned.strip()
@@ -33,10 +26,17 @@ class GroqTranscriber:
33
  api_key: str,
34
  model: str,
35
  timeout_seconds: float,
 
 
36
  ) -> None:
37
  self._client = AsyncGroq(api_key=api_key) if api_key else None
38
  self._model = model
39
  self._timeout_seconds = timeout_seconds
 
 
 
 
 
40
 
41
  @property
42
  def is_configured(self) -> bool:
@@ -51,26 +51,28 @@ class GroqTranscriber:
51
  self,
52
  filename: str,
53
  content_type: str,
54
- audio_bytes: bytes,
55
  language: str | None = None,
56
  ) -> str:
57
  if not self._client:
58
  raise GenerationError("Transcriber is not configured with GROQ_API_KEY")
59
 
 
 
60
  async def _call() -> str:
61
  response = await self._client.audio.transcriptions.create(
62
- file=(filename, audio_bytes, content_type),
63
  model=self._model,
64
  temperature=0,
65
- language=language,
66
  )
67
  text = getattr(response, "text", None)
68
  if isinstance(text, str) and text.strip():
69
- return _normalise_transcript_text(text)
70
  if isinstance(response, dict):
71
  value = response.get("text")
72
  if isinstance(value, str) and value.strip():
73
- return _normalise_transcript_text(value)
74
  raise GenerationError("Transcription response did not contain text")
75
 
76
  try:
 
9
 
10
  _FILLER_PREFIX_RE = re.compile(r"^\s*(uh+|um+|erm+|like|you know|please|hey)\s+", re.IGNORECASE)
11
  _MULTISPACE_RE = re.compile(r"\s+")
 
 
 
 
 
 
 
12
 
13
 
14
+ def _normalise_transcript_text(text: str, replacements: tuple[tuple[re.Pattern[str], str], ...]) -> str:
15
  cleaned = text.strip()
16
  cleaned = _FILLER_PREFIX_RE.sub("", cleaned)
17
+ for pattern, replacement in replacements:
18
  cleaned = pattern.sub(replacement, cleaned)
19
  cleaned = _MULTISPACE_RE.sub(" ", cleaned)
20
  return cleaned.strip()
 
26
  api_key: str,
27
  model: str,
28
  timeout_seconds: float,
29
+ default_language: str = "en",
30
+ replacements: dict[str, str] | None = None,
31
  ) -> None:
32
  self._client = AsyncGroq(api_key=api_key) if api_key else None
33
  self._model = model
34
  self._timeout_seconds = timeout_seconds
35
+ self._default_language = default_language
36
+ self._replacements = tuple(
37
+ (re.compile(pattern, re.IGNORECASE), replacement)
38
+ for pattern, replacement in (replacements or {}).items()
39
+ )
40
 
41
  @property
42
  def is_configured(self) -> bool:
 
51
  self,
52
  filename: str,
53
  content_type: str,
54
+ audio_file,
55
  language: str | None = None,
56
  ) -> str:
57
  if not self._client:
58
  raise GenerationError("Transcriber is not configured with GROQ_API_KEY")
59
 
60
+ target_language = language if language else self._default_language
61
+
62
  async def _call() -> str:
63
  response = await self._client.audio.transcriptions.create(
64
+ file=(filename, audio_file, content_type),
65
  model=self._model,
66
  temperature=0,
67
+ language=target_language,
68
  )
69
  text = getattr(response, "text", None)
70
  if isinstance(text, str) and text.strip():
71
+ return _normalise_transcript_text(text, self._replacements)
72
  if isinstance(response, dict):
73
  value = response.get("text")
74
  if isinstance(value, str) and value.strip():
75
+ return _normalise_transcript_text(value, self._replacements)
76
  raise GenerationError("Transcription response did not contain text")
77
 
78
  try:
app/services/tts_client.py CHANGED
@@ -1,7 +1,5 @@
1
  import asyncio
2
-
3
  import httpx
4
-
5
  from app.core.exceptions import GenerationError
6
 
7
 
@@ -44,3 +42,44 @@ class TTSClient:
44
  raise
45
  except Exception as exc:
46
  raise GenerationError("TTS synthesis failed", context={"error": str(exc)}) from exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
 
2
  import httpx
 
3
  from app.core.exceptions import GenerationError
4
 
5
 
 
42
  raise
43
  except Exception as exc:
44
  raise GenerationError("TTS synthesis failed", context={"error": str(exc)}) from exc
45
+
46
+ async def synthesize_stream(self, text: str, voice: str = "am_adam"):
47
+ text = text.strip()
48
+ if not text:
49
+ raise GenerationError("TTS request text is empty")
50
+
51
+ loop = asyncio.get_running_loop()
52
+ queue = asyncio.Queue()
53
+
54
+ def _worker():
55
+ try:
56
+ generator = self._pipeline(text, voice=voice, speed=1, split_pattern=r'\n+')
57
+ for gs, ps, audio in generator:
58
+ if audio is not None:
59
+ import numpy as np
60
+ pcm_audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
61
+ loop.call_soon_threadsafe(queue.put_nowait, pcm_audio)
62
+ except Exception as e:
63
+ loop.call_soon_threadsafe(queue.put_nowait, e)
64
+ finally:
65
+ loop.call_soon_threadsafe(queue.put_nowait, None)
66
+
67
+ import threading
68
+ thread = threading.Thread(target=_worker)
69
+ thread.start()
70
+
71
+ import struct
72
+ # 44-byte WAV header with 0xFFFFFFFF for sizes (streaming)
73
+ yield struct.pack('<4sI4s4sIHHIIHH4sI',
74
+ b'RIFF', 0xFFFFFFFF, b'WAVE',
75
+ b'fmt ', 16, 1, 1, 24000, 48000, 2, 16,
76
+ b'data', 0xFFFFFFFF
77
+ )
78
+
79
+ while True:
80
+ chunk = await queue.get()
81
+ if chunk is None:
82
+ break
83
+ if isinstance(chunk, Exception):
84
+ raise GenerationError("TTS synthesis stream failed", context={"error": str(chunk)}) from chunk
85
+ yield chunk
requirements.txt CHANGED
@@ -26,4 +26,6 @@ google-genai>=1.0.0
26
  # fastembed: powers BM25 sparse retrieval (Stage 2). Qdrant/bm25 vocabulary
27
  # downloads ~5 MB on first use then runs fully local — no GPU, no network at query time.
28
  fastembed>=0.3.6
29
- toon_format @ git+https://github.com/toon-format/toon-python.git
 
 
 
26
  # fastembed: powers BM25 sparse retrieval (Stage 2). Qdrant/bm25 vocabulary
27
  # downloads ~5 MB on first use then runs fully local — no GPU, no network at query time.
28
  fastembed>=0.3.6
29
+ toon_format @ git+https://github.com/toon-format/toon-python.git
30
+ kokoro>=0.9.0
31
+ soundfile>=0.13.0
test_stream.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from kokoro import KPipeline
3
+ import numpy as np
4
+ import struct
5
+ import time
6
+
7
+ async def main():
8
+ pipeline = KPipeline(lang_code='a')
9
+ generator = pipeline("Hello, this is a test of streaming audio from Kokoro.", voice="am_adam", speed=1, split_pattern=r'\n+')
10
+
11
+ with open("test_stream.wav", "wb") as f:
12
+ # Write WAV header
13
+ # chunk_size = 36 + data_size
14
+ header = struct.pack('<4sI4s4sIHHIIHH4sI',
15
+ b'RIFF', 0xFFFFFFFF, b'WAVE',
16
+ b'fmt ', 16, 1, 1, 24000, 48000, 2, 16,
17
+ b'data', 0xFFFFFFFF
18
+ )
19
+ f.write(header)
20
+
21
+ for gs, ps, audio in generator:
22
+ if audio is not None:
23
+ print("Got chunk:", len(audio))
24
+ pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16).tobytes()
25
+ f.write(pcm)
26
+
27
+ if __name__ == "__main__":
28
+ asyncio.run(main())
tests/test_speech_endpoints.py CHANGED
@@ -10,7 +10,7 @@ def test_transcribe_requires_auth(app_client):
10
 
11
 
12
  def test_transcribe_success(app_client, valid_token):
13
- async def fake_transcribe(filename, content_type, audio_bytes, language=None):
14
  await asyncio.sleep(0)
15
  return "hello from voice"
16
 
@@ -46,13 +46,13 @@ def test_tts_requires_auth(app_client):
46
  def test_tts_success(app_client, valid_token):
47
  captured: dict[str, str] = {}
48
 
49
- async def fake_synthesize(text, voice="am_adam"):
50
  await asyncio.sleep(0)
51
  captured["text"] = text
52
  captured["voice"] = voice
53
- return b"RIFF....fake"
54
 
55
- app_client.app.state.tts_client.synthesize = fake_synthesize
56
 
57
  response = app_client.post(
58
  "/tts",
@@ -62,6 +62,7 @@ def test_tts_success(app_client, valid_token):
62
 
63
  assert response.status_code == 200
64
  assert response.headers.get("content-type", "").startswith("audio/wav")
 
65
  assert response.content == b"RIFF....fake"
66
  assert captured["text"] == "Hello world"
67
  assert captured["voice"] == "am_adam"
@@ -70,12 +71,12 @@ def test_tts_success(app_client, valid_token):
70
  def test_tts_uses_provided_voice(app_client, valid_token):
71
  captured: dict[str, str] = {}
72
 
73
- async def fake_synthesize(text, voice="am_adam"):
74
  await asyncio.sleep(0)
75
  captured["voice"] = voice
76
- return b"RIFF....fake"
77
 
78
- app_client.app.state.tts_client.synthesize = fake_synthesize
79
 
80
  response = app_client.post(
81
  "/tts",
 
10
 
11
 
12
  def test_transcribe_success(app_client, valid_token):
13
+ async def fake_transcribe(filename, content_type, audio_file, language=None):
14
  await asyncio.sleep(0)
15
  return "hello from voice"
16
 
 
46
  def test_tts_success(app_client, valid_token):
47
  captured: dict[str, str] = {}
48
 
49
+ async def fake_synthesize_stream(text, voice="am_adam"):
50
  await asyncio.sleep(0)
51
  captured["text"] = text
52
  captured["voice"] = voice
53
+ yield b"RIFF....fake"
54
 
55
+ app_client.app.state.tts_client.synthesize_stream = fake_synthesize_stream
56
 
57
  response = app_client.post(
58
  "/tts",
 
62
 
63
  assert response.status_code == 200
64
  assert response.headers.get("content-type", "").startswith("audio/wav")
65
+ # StreamingResponse returns chunks, so response.content concatenates them
66
  assert response.content == b"RIFF....fake"
67
  assert captured["text"] == "Hello world"
68
  assert captured["voice"] == "am_adam"
 
71
  def test_tts_uses_provided_voice(app_client, valid_token):
72
  captured: dict[str, str] = {}
73
 
74
+ async def fake_synthesize_stream(text, voice="am_adam"):
75
  await asyncio.sleep(0)
76
  captured["voice"] = voice
77
+ yield b"RIFF....fake"
78
 
79
+ app_client.app.state.tts_client.synthesize_stream = fake_synthesize_stream
80
 
81
  response = app_client.post(
82
  "/tts",
tests/test_transcriber_normalization.py CHANGED
@@ -1,15 +1,29 @@
 
 
 
1
  from app.services.transcriber import _normalise_transcript_text
2
 
3
 
 
 
 
 
 
 
 
 
4
  def test_normalise_walk_experience_to_work_experience() -> None:
5
  query = "uh what is his walk experience in a professional setting"
6
- assert _normalise_transcript_text(query) == "what is his work experience in a professional setting"
 
7
 
8
 
9
  def test_normalise_text_stack_to_tech_stack() -> None:
10
- assert _normalise_transcript_text("what text stack does he use") == "what tech stack does he use"
 
11
 
12
 
13
  def test_keeps_clean_transcript_unchanged() -> None:
14
  original = "What technologies and skills does he work with?"
15
- assert _normalise_transcript_text(original) == original
 
 
1
+ import re
2
+
3
+ from app.core.config import get_settings
4
  from app.services.transcriber import _normalise_transcript_text
5
 
6
 
7
+ def _get_test_replacements():
8
+ replacements = get_settings().TRANSCRIBE_REPLACEMENTS
9
+ return tuple(
10
+ (re.compile(pattern, re.IGNORECASE), replacement)
11
+ for pattern, replacement in replacements.items()
12
+ )
13
+
14
+
15
  def test_normalise_walk_experience_to_work_experience() -> None:
16
  query = "uh what is his walk experience in a professional setting"
17
+ replacements = _get_test_replacements()
18
+ assert _normalise_transcript_text(query, replacements) == "what is his work experience in a professional setting"
19
 
20
 
21
  def test_normalise_text_stack_to_tech_stack() -> None:
22
+ replacements = _get_test_replacements()
23
+ assert _normalise_transcript_text("what text stack does he use", replacements) == "what tech stack does he use"
24
 
25
 
26
  def test_keeps_clean_transcript_unchanged() -> None:
27
  original = "What technologies and skills does he work with?"
28
+ replacements = _get_test_replacements()
29
+ assert _normalise_transcript_text(original, replacements) == original