GitHub Actions committed on
Commit
815b978
·
1 Parent(s): 9563e4a

Deploy a45bfc7

Browse files
app/api/transcribe.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated
2
+
3
+ from fastapi import APIRouter, Depends, File, Form, HTTPException, Request, UploadFile, status
4
+
5
+ from app.models.speech import TranscribeResponse
6
+ from app.security.jwt_auth import verify_jwt
7
+
8
+ router = APIRouter()
9
+
10
+ _ALLOWED_AUDIO_TYPES: frozenset[str] = frozenset(
11
+ {
12
+ "audio/webm",
13
+ "audio/wav",
14
+ "audio/x-wav",
15
+ "audio/mpeg",
16
+ "audio/mp3",
17
+ "audio/mp4",
18
+ "audio/ogg",
19
+ "audio/flac",
20
+ }
21
+ )
22
+
23
+
24
@router.post("")
async def transcribe_endpoint(
    request: Request,
    audio: Annotated[UploadFile, File(...)],
    _: Annotated[dict, Depends(verify_jwt)],
    language: Annotated[str | None, Form()] = None,
) -> TranscribeResponse:
    """Transcribe an uploaded audio clip to text.

    Requires a valid JWT. Validates the content type, enforces the size
    limit, and normalizes the optional language code before delegating to
    the app-level transcriber.

    Raises:
        HTTPException: 503 when the transcriber is unconfigured, 415 for an
            unsupported content type, 400 for an empty upload, 413 when the
            upload exceeds the size limit, 422 for an invalid language code.
    """
    settings = request.app.state.settings
    transcriber = request.app.state.transcriber

    if not transcriber.is_configured:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Transcription service is not configured.",
        )

    content_type = (audio.content_type or "").strip().lower()
    if content_type not in _ALLOWED_AUDIO_TYPES:
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported audio format.",
        )

    # Read at most one byte past the limit: an oversized upload is detected
    # without buffering the whole file in memory.
    audio_bytes = await audio.read(settings.TRANSCRIBE_MAX_UPLOAD_BYTES + 1)
    if not audio_bytes:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Audio file is empty.",
        )

    if len(audio_bytes) > settings.TRANSCRIBE_MAX_UPLOAD_BYTES:
        raise HTTPException(
            status_code=status.HTTP_413_CONTENT_TOO_LARGE,
            detail="Audio file exceeds maximum allowed size.",
        )

    # Blank/whitespace language becomes None; reject implausibly long codes.
    language_code = language.strip().lower() if language and language.strip() else None
    if language_code and len(language_code) > 10:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Invalid language code.",
        )

    transcript = await transcriber.transcribe(
        filename=audio.filename or "audio.webm",
        content_type=content_type,
        audio_bytes=audio_bytes,
        language=language_code,
    )

    return TranscribeResponse(transcript=transcript)
app/api/tts.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated
2
+
3
+ from fastapi import APIRouter, Depends, HTTPException, Request, status
4
+ from fastapi.responses import Response
5
+
6
+ from app.models.speech import SynthesizeRequest
7
+ from app.security.jwt_auth import verify_jwt
8
+
9
+ router = APIRouter()
10
+
11
+
12
@router.post("")
async def synthesize_endpoint(
    request: Request,
    payload: SynthesizeRequest,
    _: Annotated[dict, Depends(verify_jwt)],
) -> Response:
    """Synthesize speech for the given text and return WAV audio bytes.

    Requires a valid JWT.

    Raises:
        HTTPException: 503 when the TTS backend is unconfigured, 422 when
            the text is blank after trimming.
    """
    tts_client = request.app.state.tts_client
    if not tts_client.is_configured:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="TTS service is not configured.",
        )

    # Field(min_length=1) still admits whitespace-only input; reject it here
    # instead of sending an empty string to the upstream synthesizer.
    text = payload.text.strip()
    if not text:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="Text must not be empty.",
        )

    audio_bytes = await tts_client.synthesize(text)
    return Response(content=audio_bytes, media_type="audio/wav")
app/core/config.py CHANGED
@@ -12,6 +12,7 @@ class Settings(BaseSettings):
12
  OLLAMA_MODEL: Optional[str] = None
13
  GROQ_MODEL_DEFAULT: str = "llama-3.1-8b-instant"
14
  GROQ_MODEL_LARGE: str = "llama-3.3-70b-versatile"
 
15
 
16
  # Vector
17
  QDRANT_URL: str
@@ -67,6 +68,11 @@ class Settings(BaseSettings):
67
  # In prod, the API Space calls the HF embedder/reranker Spaces via HTTP.
68
  EMBEDDER_URL: str = "http://localhost:7860"
69
  RERANKER_URL: str = "http://localhost:7861"
 
 
 
 
 
70
 
71
  model_config = SettingsConfigDict(env_file=".env", extra="ignore")
72
 
 
12
  OLLAMA_MODEL: Optional[str] = None
13
  GROQ_MODEL_DEFAULT: str = "llama-3.1-8b-instant"
14
  GROQ_MODEL_LARGE: str = "llama-3.3-70b-versatile"
15
+ GROQ_TRANSCRIBE_MODEL: str = "whisper-large-v3-turbo"
16
 
17
  # Vector
18
  QDRANT_URL: str
 
68
  # In prod, the API Space calls the HF embedder/reranker Spaces via HTTP.
69
  EMBEDDER_URL: str = "http://localhost:7860"
70
  RERANKER_URL: str = "http://localhost:7861"
71
+ TTS_SPACE_URL: str = "http://localhost:7862"
72
+
73
+ # Speech-to-text upload constraints
74
+ TRANSCRIBE_MAX_UPLOAD_BYTES: int = 2 * 1024 * 1024
75
+ TRANSCRIBE_TIMEOUT_SECONDS: float = 25.0
76
 
77
  model_config = SettingsConfigDict(env_file=".env", extra="ignore")
78
 
app/main.py CHANGED
@@ -14,6 +14,8 @@ from app.api.admin import router as admin_router
14
  from app.api.chat import router as chat_router
15
  from app.api.feedback import router as feedback_router
16
  from app.api.health import router as health_router
 
 
17
  from app.core.config import get_settings
18
  from app.core.exceptions import AppError
19
  from app.core.logging import get_logger
@@ -25,6 +27,8 @@ from app.services.github_log import GithubLog
25
  from app.services.llm_client import get_llm_client, TpmBucket
26
  from app.services.reranker import Reranker
27
  from app.services.semantic_cache import SemanticCache
 
 
28
  from app.services.conversation_store import ConversationStore
29
  from qdrant_client import QdrantClient
30
 
@@ -156,6 +160,15 @@ async def lifespan(app: FastAPI):
156
  context_path=settings.GEMINI_CONTEXT_PATH,
157
  )
158
  app.state.gemini_client = gemini_client
 
 
 
 
 
 
 
 
 
159
 
160
  from app.services.vector_store import VectorStore
161
  from app.security.guard_classifier import GuardClassifier
@@ -286,6 +299,8 @@ def create_app() -> FastAPI:
286
 
287
  app.include_router(health_router, tags=["Health"])
288
  app.include_router(chat_router, prefix="/chat", tags=["Chat"])
 
 
289
  app.include_router(feedback_router, prefix="/chat", tags=["Feedback"])
290
  app.include_router(admin_router, prefix="/admin", tags=["Admin"])
291
 
 
14
  from app.api.chat import router as chat_router
15
  from app.api.feedback import router as feedback_router
16
  from app.api.health import router as health_router
17
+ from app.api.tts import router as tts_router
18
+ from app.api.transcribe import router as transcribe_router
19
  from app.core.config import get_settings
20
  from app.core.exceptions import AppError
21
  from app.core.logging import get_logger
 
27
  from app.services.llm_client import get_llm_client, TpmBucket
28
  from app.services.reranker import Reranker
29
  from app.services.semantic_cache import SemanticCache
30
+ from app.services.transcriber import GroqTranscriber
31
+ from app.services.tts_client import TTSClient
32
  from app.services.conversation_store import ConversationStore
33
  from qdrant_client import QdrantClient
34
 
 
160
  context_path=settings.GEMINI_CONTEXT_PATH,
161
  )
162
  app.state.gemini_client = gemini_client
163
+ app.state.transcriber = GroqTranscriber(
164
+ api_key=settings.GROQ_API_KEY or "",
165
+ model=settings.GROQ_TRANSCRIBE_MODEL,
166
+ timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
167
+ )
168
+ app.state.tts_client = TTSClient(
169
+ tts_space_url=settings.TTS_SPACE_URL,
170
+ timeout_seconds=settings.TRANSCRIBE_TIMEOUT_SECONDS,
171
+ )
172
 
173
  from app.services.vector_store import VectorStore
174
  from app.security.guard_classifier import GuardClassifier
 
299
 
300
  app.include_router(health_router, tags=["Health"])
301
  app.include_router(chat_router, prefix="/chat", tags=["Chat"])
302
+ app.include_router(transcribe_router, prefix="/transcribe", tags=["Transcribe"])
303
+ app.include_router(tts_router, prefix="/tts", tags=["TTS"])
304
  app.include_router(feedback_router, prefix="/chat", tags=["Feedback"])
305
  app.include_router(admin_router, prefix="/admin", tags=["Admin"])
306
 
app/models/speech.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from pydantic import BaseModel, Field


class TranscribeResponse(BaseModel):
    """Response body for POST /transcribe."""

    # Transcribed text; non-empty, capped at 5000 characters.
    transcript: str = Field(min_length=1, max_length=5000)


class SynthesizeRequest(BaseModel):
    """Request body for POST /tts."""

    # Text to synthesize; non-empty, capped at 300 characters.
    text: str = Field(min_length=1, max_length=300)
app/services/transcriber.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio

import httpx
from groq import AsyncGroq
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

from app.core.exceptions import GenerationError


class GroqTranscriber:
    """Speech-to-text client backed by Groq's audio transcription API."""

    def __init__(
        self,
        api_key: str,
        model: str,
        timeout_seconds: float,
    ) -> None:
        # An empty API key leaves the client unconfigured instead of
        # failing at startup; callers check is_configured.
        self._client = AsyncGroq(api_key=api_key) if api_key else None
        self._model = model
        self._timeout_seconds = timeout_seconds

    @property
    def is_configured(self) -> bool:
        """True when an API key was supplied at construction."""
        return self._client is not None

    async def transcribe(
        self,
        filename: str,
        content_type: str,
        audio_bytes: bytes,
        language: str | None = None,
    ) -> str:
        """Transcribe audio bytes to text.

        Raises:
            GenerationError: when unconfigured, on timeout, on upstream
                failure, or when the response contains no text.
        """
        if not self._client:
            raise GenerationError("Transcriber is not configured with GROQ_API_KEY")

        try:
            return await self._attempt(filename, content_type, audio_bytes, language)
        except TimeoutError as exc:
            raise GenerationError("Transcription timed out") from exc
        except GenerationError:
            raise
        except Exception as exc:
            raise GenerationError("Transcription failed", context={"error": str(exc)}) from exc

    # BUG FIX: the retry previously wrapped transcribe(), whose generic
    # `except Exception` converted transport errors to GenerationError before
    # tenacity could match them — so the retry never fired. Retrying here,
    # beneath the error-conversion layer, makes it effective.
    @retry(
        stop=stop_after_attempt(2),
        wait=wait_fixed(0.8),
        retry=retry_if_exception_type((httpx.RequestError, httpx.TimeoutException)),
    )
    async def _attempt(
        self,
        filename: str,
        content_type: str,
        audio_bytes: bytes,
        language: str | None,
    ) -> str:
        """One bounded transcription attempt; transient httpx errors are retried."""
        # Bound each attempt by the configured timeout; TimeoutError is not
        # in the retry predicate, so it escapes to transcribe() directly.
        response = await asyncio.wait_for(
            self._client.audio.transcriptions.create(
                file=(filename, audio_bytes, content_type),
                model=self._model,
                temperature=0,
                language=language,
            ),
            timeout=self._timeout_seconds,
        )
        text = getattr(response, "text", None)
        if isinstance(text, str) and text.strip():
            return text.strip()
        # Some SDK paths may return a plain dict — presumably mirrors the
        # attribute shape; verified defensively here.
        if isinstance(response, dict):
            value = response.get("text")
            if isinstance(value, str) and value.strip():
                return value.strip()
        raise GenerationError("Transcription response did not contain text")
app/services/tts_client.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio

import httpx

from app.core.exceptions import GenerationError


class TTSClient:
    """Thin HTTP client for the external TTS Space's /synthesize endpoint."""

    def __init__(self, tts_space_url: str, timeout_seconds: float) -> None:
        # Normalize the base URL so the path join below never double-slashes.
        self._tts_space_url = tts_space_url.rstrip("/")
        self._timeout_seconds = timeout_seconds

    @property
    def is_configured(self) -> bool:
        """True when a non-empty base URL was provided."""
        return bool(self._tts_space_url)

    async def synthesize(self, text: str) -> bytes:
        """POST *text* to the TTS Space and return the audio payload bytes.

        Raises:
            GenerationError: when unconfigured, on timeout, on an upstream
                error status, or when the response body is empty.
        """
        if not self.is_configured:
            raise GenerationError("TTS client is not configured")

        async def _request() -> bytes:
            endpoint = f"{self._tts_space_url}/synthesize"
            async with httpx.AsyncClient(timeout=self._timeout_seconds) as http:
                reply = await http.post(
                    endpoint,
                    json={"text": text},
                    headers={"Content-Type": "application/json"},
                )
            reply.raise_for_status()
            payload = reply.content
            if not payload:
                raise GenerationError("TTS response was empty")
            return payload

        try:
            return await asyncio.wait_for(_request(), timeout=self._timeout_seconds)
        except TimeoutError as exc:
            raise GenerationError("TTS request timed out") from exc
        except httpx.HTTPStatusError as exc:
            raise GenerationError(
                "TTS upstream returned an error",
                context={"status_code": exc.response.status_code},
            ) from exc
        except GenerationError:
            raise
        except Exception as exc:
            raise GenerationError("TTS synthesis failed", context={"error": str(exc)}) from exc
requirements.txt CHANGED
@@ -9,6 +9,7 @@
9
  fastapi>=0.115.0
10
  uvicorn[standard]>=0.29.0
11
  uvloop>=0.19.0
 
12
  pydantic-settings>=2.2.1
13
  langgraph>=0.2.0
14
  qdrant-client==1.9.1
 
9
  fastapi>=0.115.0
10
  uvicorn[standard]>=0.29.0
11
  uvloop>=0.19.0
12
+ python-multipart>=0.0.9
13
  pydantic-settings>=2.2.1
14
  langgraph>=0.2.0
15
  qdrant-client==1.9.1
tests/test_speech_endpoints.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+
3
+
4
def test_transcribe_requires_auth(app_client):
    """A request without an Authorization header is rejected with 401."""
    resp = app_client.post(
        "/transcribe",
        files={"audio": ("sample.webm", b"abc", "audio/webm")},
    )

    assert resp.status_code == 401
10
+
11
+
12
def test_transcribe_success(app_client, valid_token, monkeypatch):
    """A stubbed transcriber result is wrapped in the response model."""

    async def fake_transcribe(filename, content_type, audio_bytes, language=None):
        await asyncio.sleep(0)
        return "hello from voice"

    # monkeypatch restores the real method after the test, so the stub
    # cannot leak into other tests sharing the app fixture.
    monkeypatch.setattr(app_client.app.state.transcriber, "transcribe", fake_transcribe)

    response = app_client.post(
        "/transcribe",
        files={"audio": ("sample.webm", b"abc", "audio/webm")},
        headers={"Authorization": f"Bearer {valid_token}"},
    )

    assert response.status_code == 200
    assert response.json()["transcript"] == "hello from voice"
27
+
28
+
29
def test_transcribe_rejects_oversized_audio(app_client, valid_token, monkeypatch):
    """Uploads larger than TRANSCRIBE_MAX_UPLOAD_BYTES get a 413."""
    # monkeypatch restores the original setting afterwards; the previous
    # direct assignment permanently shrank the limit for later tests.
    monkeypatch.setattr(app_client.app.state.settings, "TRANSCRIBE_MAX_UPLOAD_BYTES", 2)

    response = app_client.post(
        "/transcribe",
        files={"audio": ("sample.webm", b"abcdef", "audio/webm")},
        headers={"Authorization": f"Bearer {valid_token}"},
    )

    assert response.status_code == 413
39
+
40
+
41
def test_tts_requires_auth(app_client):
    """A synthesis request without credentials is rejected with 401."""
    resp = app_client.post("/tts", json={"text": "Hello world"})

    assert resp.status_code == 401
44
+
45
+
46
def test_tts_success(app_client, valid_token, monkeypatch):
    """Stubbed TTS bytes are returned verbatim with an audio/wav content type."""

    async def fake_synthesize(text):
        await asyncio.sleep(0)
        return b"RIFF....fake"

    # monkeypatch undoes the stub after the test, keeping the shared
    # app fixture clean for other tests.
    monkeypatch.setattr(app_client.app.state.tts_client, "synthesize", fake_synthesize)

    response = app_client.post(
        "/tts",
        json={"text": "Hello world"},
        headers={"Authorization": f"Bearer {valid_token}"},
    )

    assert response.status_code == 200
    assert response.headers.get("content-type", "").startswith("audio/wav")
    assert response.content == b"RIFF....fake"