Spaces:
Sleeping
Sleeping
lordofgaming commited on
Commit ·
673435a
1
Parent(s): fbdfd83
Initial VoiceForge deployment (clean)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +39 -0
- README.md +20 -6
- backend/.env +16 -0
- backend/.flake8 +4 -0
- backend/Dockerfile +50 -0
- backend/app/__init__.py +3 -0
- backend/app/api/__init__.py +3 -0
- backend/app/api/routes/__init__.py +35 -0
- backend/app/api/routes/analysis.py +60 -0
- backend/app/api/routes/audio.py +100 -0
- backend/app/api/routes/auth.py +116 -0
- backend/app/api/routes/batch.py +204 -0
- backend/app/api/routes/cloning.py +81 -0
- backend/app/api/routes/health.py +93 -0
- backend/app/api/routes/s2s.py +45 -0
- backend/app/api/routes/sign.py +164 -0
- backend/app/api/routes/sign_bridge.py +63 -0
- backend/app/api/routes/stt.py +489 -0
- backend/app/api/routes/transcripts.py +200 -0
- backend/app/api/routes/translation.py +261 -0
- backend/app/api/routes/tts.py +245 -0
- backend/app/api/routes/ws.py +208 -0
- backend/app/core/__init__.py +7 -0
- backend/app/core/config.py +108 -0
- backend/app/core/limiter.py +27 -0
- backend/app/core/middleware.py +70 -0
- backend/app/core/request_size_middleware.py +91 -0
- backend/app/core/security.py +113 -0
- backend/app/core/security_encryption.py +107 -0
- backend/app/core/security_headers.py +37 -0
- backend/app/core/ws_security.py +181 -0
- backend/app/main.py +273 -0
- backend/app/models/__init__.py +19 -0
- backend/app/models/audio_file.py +47 -0
- backend/app/models/auth.py +36 -0
- backend/app/models/base.py +44 -0
- backend/app/models/sign_lstm.py +63 -0
- backend/app/models/transcript.py +67 -0
- backend/app/schemas/__init__.py +39 -0
- backend/app/schemas/stt.py +98 -0
- backend/app/schemas/transcript.py +69 -0
- backend/app/schemas/tts.py +67 -0
- backend/app/services/__init__.py +13 -0
- backend/app/services/audio_service.py +101 -0
- backend/app/services/batch_service.py +348 -0
- backend/app/services/cache_service.py +71 -0
- backend/app/services/clone_service.py +104 -0
- backend/app/services/diarization_service.py +338 -0
- backend/app/services/edge_tts_service.py +357 -0
- backend/app/services/emotion_service.py +132 -0
Dockerfile
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CPU-only base for free-tier HuggingFace Spaces
|
| 2 |
+
FROM python:3.10-slim
|
| 3 |
+
|
| 4 |
+
# Set environment variables
|
| 5 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 6 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 7 |
+
DEBIAN_FRONTEND=noninteractive
|
| 8 |
+
|
| 9 |
+
# Install system dependencies
|
| 10 |
+
RUN apt-get update && apt-get install -y \
|
| 11 |
+
python3-pip \
|
| 12 |
+
python3-dev \
|
| 13 |
+
ffmpeg \
|
| 14 |
+
libsndfile1 \
|
| 15 |
+
git \
|
| 16 |
+
supervisor \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
# Set working directory
|
| 20 |
+
WORKDIR /app
|
| 21 |
+
|
| 22 |
+
# Copy requirements and install
|
| 23 |
+
COPY deploy/huggingface/requirements.txt .
|
| 24 |
+
RUN pip3 install --no-cache-dir -r requirements.txt
|
| 25 |
+
|
| 26 |
+
# Copy the rest of the application
|
| 27 |
+
COPY . .
|
| 28 |
+
|
| 29 |
+
# Create directory for weights and logs
|
| 30 |
+
RUN mkdir -p backend/app/models/weights /var/log/supervisor
|
| 31 |
+
|
| 32 |
+
# Copy supervisor config
|
| 33 |
+
COPY deploy/huggingface/supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
| 34 |
+
|
| 35 |
+
# Expose ports (HF expects the app on 7860)
|
| 36 |
+
EXPOSE 7860
|
| 37 |
+
|
| 38 |
+
# Run supervisor to start both backend and frontend
|
| 39 |
+
CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
|
README.md
CHANGED
|
@@ -1,10 +1,24 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
pinned:
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: VoiceForge Universal
|
| 3 |
+
emoji: 🎙️
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: pink
|
| 6 |
sdk: docker
|
| 7 |
+
pinned: true
|
| 8 |
+
license: mit
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# VoiceForge: Universal Communication Platform
|
| 12 |
+
|
| 13 |
+
**Instant Speech-to-Speech | Signed Communication | Meeting Intelligence**
|
| 14 |
+
|
| 15 |
+
Powered by:
|
| 16 |
+
- **Whisper** (STT)
|
| 17 |
+
- **SeamlessM4T** (S2S)
|
| 18 |
+
- **MediaPipe** (Sign Recognition)
|
| 19 |
+
- **Edge TTS** (Synthesis)
|
| 20 |
+
|
| 21 |
+
## 🚀 Usage
|
| 22 |
+
1. Click the "Speech-to-Speech" tab to translate voice instantly.
|
| 23 |
+
2. Use "Signed Communication" to visualize ASL.
|
| 24 |
+
3. Upload meetings for instant minutes.
|
backend/.env
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# VoiceForge Backend Environment Variables
|
| 2 |
+
|
| 3 |
+
# Database
|
| 4 |
+
DATABASE_URL=sqlite:///./voiceforge.db
|
| 5 |
+
|
| 6 |
+
# Hugging Face Token (for Speaker Diarization)
|
| 7 |
+
# Get your token at: https://huggingface.co/settings/tokens
|
| 8 |
+
# See docs/HF_TOKEN_ROTATION.md for setup instructions
|
| 9 |
+
HF_TOKEN=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
|
| 10 |
+
|
| 11 |
+
# Encryption Key (auto-generated in dev, REQUIRED in production)
|
| 12 |
+
# Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
|
| 13 |
+
# ENCRYPTION_KEY=
|
| 14 |
+
|
| 15 |
+
# Environment (development | production)
|
| 16 |
+
ENVIRONMENT=development
|
backend/.flake8
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[flake8]
|
| 2 |
+
max-line-length = 120
|
| 3 |
+
extend-ignore = E203
|
| 4 |
+
exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,venv
|
backend/Dockerfile
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Build Stage
|
| 2 |
+
FROM python:3.10-slim as builder
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
# Set environment variables
|
| 7 |
+
ENV PYTHONDONTWRITEBYTECODE 1
|
| 8 |
+
ENV PYTHONUNBUFFERED 1
|
| 9 |
+
|
| 10 |
+
# Install system dependencies required for building python packages
|
| 11 |
+
# ffmpeg is needed for audio processing
|
| 12 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 13 |
+
gcc \
|
| 14 |
+
ffmpeg \
|
| 15 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 16 |
+
|
| 17 |
+
# Install python dependencies
|
| 18 |
+
COPY requirements.txt .
|
| 19 |
+
RUN pip wheel --no-cache-dir --no-deps --wheel-dir /app/wheels -r requirements.txt
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Final Stage
|
| 23 |
+
FROM python:3.10-slim
|
| 24 |
+
|
| 25 |
+
WORKDIR /app
|
| 26 |
+
|
| 27 |
+
# Install runtime dependencies (ffmpeg)
|
| 28 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 29 |
+
ffmpeg \
|
| 30 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 31 |
+
|
| 32 |
+
# Copy wheels from builder
|
| 33 |
+
COPY --from=builder /app/wheels /wheels
|
| 34 |
+
COPY --from=builder /app/requirements.txt .
|
| 35 |
+
|
| 36 |
+
# Install dependencies from wheels
|
| 37 |
+
RUN pip install --no-cache /wheels/*
|
| 38 |
+
|
| 39 |
+
# Copy application code
|
| 40 |
+
COPY . .
|
| 41 |
+
|
| 42 |
+
# Create a non-root user
|
| 43 |
+
RUN addgroup --system app && adduser --system --group app
|
| 44 |
+
USER app
|
| 45 |
+
|
| 46 |
+
# Expose port
|
| 47 |
+
EXPOSE 8000
|
| 48 |
+
|
| 49 |
+
# Run commands
|
| 50 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
backend/app/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge Backend Package
|
| 3 |
+
"""
|
backend/app/api/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge API Package
|
| 3 |
+
"""
|
backend/app/api/routes/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge API Routes Package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .stt import router as stt_router
|
| 6 |
+
from .tts import router as tts_router
|
| 7 |
+
from .health import router as health_router
|
| 8 |
+
from .transcripts import router as transcripts_router
|
| 9 |
+
from .ws import router as ws_router
|
| 10 |
+
from .translation import router as translation_router
|
| 11 |
+
from .batch import router as batch_router
|
| 12 |
+
from .analysis import router as analysis_router
|
| 13 |
+
from .audio import router as audio_router
|
| 14 |
+
from .cloning import router as cloning_router
|
| 15 |
+
from .sign import router as sign_router
|
| 16 |
+
from .auth import router as auth_router
|
| 17 |
+
from .s2s import router as s2s_router
|
| 18 |
+
from .sign_bridge import router as sign_bridge_router
|
| 19 |
+
|
| 20 |
+
__all__ = [
|
| 21 |
+
"stt_router",
|
| 22 |
+
"tts_router",
|
| 23 |
+
"health_router",
|
| 24 |
+
"transcripts_router",
|
| 25 |
+
"ws_router",
|
| 26 |
+
"translation_router",
|
| 27 |
+
"batch_router",
|
| 28 |
+
"analysis_router",
|
| 29 |
+
"audio_router",
|
| 30 |
+
"cloning_router",
|
| 31 |
+
"sign_router",
|
| 32 |
+
"auth_router",
|
| 33 |
+
"s2s_router",
|
| 34 |
+
"sign_bridge_router",
|
| 35 |
+
]
|
backend/app/api/routes/analysis.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Analysis API Routes
|
| 3 |
+
Endpoints for Emotion and Sentiment Analysis
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import shutil
|
| 11 |
+
import tempfile
|
| 12 |
+
|
| 13 |
+
from app.services.emotion_service import get_emotion_service
|
| 14 |
+
from app.services.nlp_service import get_nlp_service
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
router = APIRouter(prefix="/analysis", tags=["Analysis"])
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@router.post("/emotion/audio")
|
| 21 |
+
async def analyze_audio_emotion(
|
| 22 |
+
file: UploadFile = File(..., description="Audio file to analyze"),
|
| 23 |
+
):
|
| 24 |
+
"""
|
| 25 |
+
Analyze emotions in an audio file using Wav2Vec2.
|
| 26 |
+
Returns dominant emotion and probability distribution.
|
| 27 |
+
"""
|
| 28 |
+
service = get_emotion_service()
|
| 29 |
+
|
| 30 |
+
# Save to temp file
|
| 31 |
+
suffix = os.path.splitext(file.filename)[1] or ".wav"
|
| 32 |
+
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 33 |
+
shutil.copyfileobj(file.file, tmp)
|
| 34 |
+
tmp_path = tmp.name
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
result = service.analyze_audio(tmp_path)
|
| 38 |
+
return result
|
| 39 |
+
except Exception as e:
|
| 40 |
+
logger.error(f"Emotion analysis failed: {e}")
|
| 41 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 42 |
+
finally:
|
| 43 |
+
try:
|
| 44 |
+
os.unlink(tmp_path)
|
| 45 |
+
except:
|
| 46 |
+
pass
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
@router.post("/sentiment/text")
|
| 50 |
+
async def analyze_text_sentiment(
|
| 51 |
+
text: str = Form(..., description="Text to analyze"),
|
| 52 |
+
):
|
| 53 |
+
"""
|
| 54 |
+
Analyze text sentiment (polarity and subjectivity).
|
| 55 |
+
"""
|
| 56 |
+
nlp_service = get_nlp_service()
|
| 57 |
+
try:
|
| 58 |
+
return nlp_service.analyze_sentiment(text)
|
| 59 |
+
except Exception as e:
|
| 60 |
+
raise HTTPException(status_code=500, detail=str(e))
|
backend/app/api/routes/audio.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio Editing API Routes
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
|
| 6 |
+
from fastapi.responses import FileResponse
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
import os
|
| 9 |
+
import shutil
|
| 10 |
+
import tempfile
|
| 11 |
+
import uuid
|
| 12 |
+
|
| 13 |
+
from app.services.audio_service import get_audio_service, AudioService
|
| 14 |
+
|
| 15 |
+
router = APIRouter(prefix="/audio", tags=["Audio Studio"])
|
| 16 |
+
|
| 17 |
+
@router.post("/trim")
|
| 18 |
+
async def trim_audio(
|
| 19 |
+
file: UploadFile = File(..., description="Audio file"),
|
| 20 |
+
start_sec: float = Form(..., description="Start time in seconds"),
|
| 21 |
+
end_sec: float = Form(..., description="End time in seconds"),
|
| 22 |
+
service: AudioService = Depends(get_audio_service)
|
| 23 |
+
):
|
| 24 |
+
"""Trim an audio file"""
|
| 25 |
+
suffix = os.path.splitext(file.filename)[1] or ".mp3"
|
| 26 |
+
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 27 |
+
shutil.copyfileobj(file.file, tmp)
|
| 28 |
+
tmp_path = tmp.name
|
| 29 |
+
|
| 30 |
+
try:
|
| 31 |
+
output_path = tmp_path.replace(suffix, f"_trimmed{suffix}")
|
| 32 |
+
service.trim_audio(tmp_path, int(start_sec * 1000), int(end_sec * 1000), output_path)
|
| 33 |
+
|
| 34 |
+
return FileResponse(
|
| 35 |
+
output_path,
|
| 36 |
+
filename=f"trimmed_{file.filename}",
|
| 37 |
+
background=None # Let FastAPI handle cleanup? No, we need custom cleanup or use background task
|
| 38 |
+
)
|
| 39 |
+
except Exception as e:
|
| 40 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 41 |
+
# Note: Temp files might persist. In prod, use a cleanup task.
|
| 42 |
+
|
| 43 |
+
@router.post("/merge")
|
| 44 |
+
async def merge_audio(
|
| 45 |
+
files: List[UploadFile] = File(..., description="Files to merge"),
|
| 46 |
+
format: str = Form("mp3", description="Output format"),
|
| 47 |
+
service: AudioService = Depends(get_audio_service)
|
| 48 |
+
):
|
| 49 |
+
"""Merge multiple audio files"""
|
| 50 |
+
temp_files = []
|
| 51 |
+
try:
|
| 52 |
+
for file in files:
|
| 53 |
+
suffix = os.path.splitext(file.filename)[1] or ".mp3"
|
| 54 |
+
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
| 55 |
+
content = await file.read()
|
| 56 |
+
tmp.write(content)
|
| 57 |
+
tmp.close()
|
| 58 |
+
temp_files.append(tmp.name)
|
| 59 |
+
|
| 60 |
+
output_filename = f"merged_{uuid.uuid4()}.{format}"
|
| 61 |
+
output_path = os.path.join(tempfile.gettempdir(), output_filename)
|
| 62 |
+
|
| 63 |
+
service.merge_audio(temp_files, output_path)
|
| 64 |
+
|
| 65 |
+
return FileResponse(output_path, filename=output_filename)
|
| 66 |
+
|
| 67 |
+
except Exception as e:
|
| 68 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 69 |
+
finally:
|
| 70 |
+
for p in temp_files:
|
| 71 |
+
try:
|
| 72 |
+
os.unlink(p)
|
| 73 |
+
except:
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
@router.post("/convert")
|
| 77 |
+
async def convert_audio(
|
| 78 |
+
file: UploadFile = File(..., description="Audio file"),
|
| 79 |
+
target_format: str = Form(..., description="Target format (mp3, wav, flac, ogg)"),
|
| 80 |
+
service: AudioService = Depends(get_audio_service)
|
| 81 |
+
):
|
| 82 |
+
"""Convert audio format"""
|
| 83 |
+
suffix = os.path.splitext(file.filename)[1] or ".wav"
|
| 84 |
+
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
|
| 85 |
+
shutil.copyfileobj(file.file, tmp)
|
| 86 |
+
tmp_path = tmp.name
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
output_path = service.convert_format(tmp_path, target_format)
|
| 90 |
+
return FileResponse(
|
| 91 |
+
output_path,
|
| 92 |
+
filename=f"{os.path.splitext(file.filename)[0]}.{target_format}"
|
| 93 |
+
)
|
| 94 |
+
except Exception as e:
|
| 95 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 96 |
+
finally:
|
| 97 |
+
try:
|
| 98 |
+
os.unlink(tmp_path)
|
| 99 |
+
except:
|
| 100 |
+
pass
|
backend/app/api/routes/auth.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime, timedelta
|
| 2 |
+
from typing import List
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
import secrets
|
| 5 |
+
from fastapi import APIRouter, Depends, HTTPException, status
|
| 6 |
+
from fastapi.security import OAuth2PasswordRequestForm
|
| 7 |
+
from sqlalchemy.orm import Session
|
| 8 |
+
|
| 9 |
+
from ...core.security import (
|
| 10 |
+
create_access_token,
|
| 11 |
+
get_password_hash,
|
| 12 |
+
verify_password,
|
| 13 |
+
get_current_active_user,
|
| 14 |
+
ACCESS_TOKEN_EXPIRE_MINUTES
|
| 15 |
+
)
|
| 16 |
+
from ...models import get_db, User, ApiKey
|
| 17 |
+
from ...core.limiter import limiter
|
| 18 |
+
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
| 19 |
+
|
| 20 |
+
router = APIRouter(prefix="/auth", tags=["Authentication"])
|
| 21 |
+
|
| 22 |
+
# --- Schemas ---
|
| 23 |
+
class Token(BaseModel):
|
| 24 |
+
access_token: str
|
| 25 |
+
token_type: str
|
| 26 |
+
|
| 27 |
+
class UserCreate(BaseModel):
|
| 28 |
+
email: str
|
| 29 |
+
password: str
|
| 30 |
+
full_name: str = None
|
| 31 |
+
|
| 32 |
+
class UserOut(BaseModel):
|
| 33 |
+
id: int
|
| 34 |
+
email: str
|
| 35 |
+
full_name: str = None
|
| 36 |
+
is_active: bool
|
| 37 |
+
|
| 38 |
+
class Config:
|
| 39 |
+
orm_mode = True
|
| 40 |
+
|
| 41 |
+
class ApiKeyCreate(BaseModel):
|
| 42 |
+
name: str
|
| 43 |
+
|
| 44 |
+
class ApiKeyOut(BaseModel):
|
| 45 |
+
key: str
|
| 46 |
+
name: str
|
| 47 |
+
created_at: datetime
|
| 48 |
+
|
| 49 |
+
class Config:
|
| 50 |
+
orm_mode = True
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# --- Endpoints ---
|
| 54 |
+
|
| 55 |
+
@router.post("/register", response_model=UserOut)
|
| 56 |
+
@limiter.limit("5/minute")
|
| 57 |
+
async def register(request: Request, user_in: UserCreate, db: Session = Depends(get_db)):
|
| 58 |
+
"""Register a new user"""
|
| 59 |
+
existing_user = db.query(User).filter(User.email == user_in.email).first()
|
| 60 |
+
if existing_user:
|
| 61 |
+
raise HTTPException(status_code=400, detail="Email already registered")
|
| 62 |
+
|
| 63 |
+
hashed_password = get_password_hash(user_in.password)
|
| 64 |
+
new_user = User(
|
| 65 |
+
email=user_in.email,
|
| 66 |
+
hashed_password=hashed_password,
|
| 67 |
+
full_name=user_in.full_name
|
| 68 |
+
)
|
| 69 |
+
db.add(new_user)
|
| 70 |
+
db.commit()
|
| 71 |
+
db.refresh(new_user)
|
| 72 |
+
return new_user
|
| 73 |
+
|
| 74 |
+
@router.post("/login", response_model=Token)
|
| 75 |
+
@limiter.limit("5/minute")
|
| 76 |
+
async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
|
| 77 |
+
"""Login to get access token"""
|
| 78 |
+
user = db.query(User).filter(User.email == form_data.username).first()
|
| 79 |
+
if not user or not verify_password(form_data.password, user.hashed_password):
|
| 80 |
+
raise HTTPException(
|
| 81 |
+
status_code=status.HTTP_401_UNAUTHORIZED,
|
| 82 |
+
detail="Incorrect email or password",
|
| 83 |
+
headers={"WWW-Authenticate": "Bearer"},
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
| 87 |
+
access_token = create_access_token(
|
| 88 |
+
subject=user.id, expires_delta=access_token_expires
|
| 89 |
+
)
|
| 90 |
+
return {"access_token": access_token, "token_type": "bearer"}
|
| 91 |
+
|
| 92 |
+
@router.post("/api-keys", response_model=ApiKeyOut)
|
| 93 |
+
async def create_api_key(
|
| 94 |
+
key_in: ApiKeyCreate,
|
| 95 |
+
current_user: User = Depends(get_current_active_user),
|
| 96 |
+
db: Session = Depends(get_db)
|
| 97 |
+
):
|
| 98 |
+
"""Generate a new API key for the current user"""
|
| 99 |
+
# Generate secure 32-char key
|
| 100 |
+
raw_key = secrets.token_urlsafe(32)
|
| 101 |
+
api_key_str = f"vf_{raw_key}" # Prefix for identification
|
| 102 |
+
|
| 103 |
+
new_key = ApiKey(
|
| 104 |
+
key=api_key_str,
|
| 105 |
+
name=key_in.name,
|
| 106 |
+
user_id=current_user.id
|
| 107 |
+
)
|
| 108 |
+
db.add(new_key)
|
| 109 |
+
db.commit()
|
| 110 |
+
db.refresh(new_key)
|
| 111 |
+
return new_key
|
| 112 |
+
|
| 113 |
+
@router.get("/me", response_model=UserOut)
|
| 114 |
+
async def read_users_me(current_user: User = Depends(get_current_active_user)):
|
| 115 |
+
"""Get current user details"""
|
| 116 |
+
return current_user
|
backend/app/api/routes/batch.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch Processing API Routes
|
| 3 |
+
Endpoints for submitting and managing batch transcription jobs
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends, BackgroundTasks
|
| 7 |
+
from fastapi.responses import FileResponse
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
from typing import List, Optional, Dict, Any
|
| 10 |
+
import logging
|
| 11 |
+
import shutil
|
| 12 |
+
import os
|
| 13 |
+
import tempfile
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
from app.services.batch_service import get_batch_service
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
router = APIRouter(prefix="/batch", tags=["batch"])
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Request/Response Models
|
| 23 |
+
class BatchJobResponse(BaseModel):
|
| 24 |
+
"""Batch job response model."""
|
| 25 |
+
job_id: str
|
| 26 |
+
status: str
|
| 27 |
+
progress: float
|
| 28 |
+
created_at: str
|
| 29 |
+
total_files: int
|
| 30 |
+
completed_files: int
|
| 31 |
+
failed_files: int
|
| 32 |
+
has_zip: bool
|
| 33 |
+
files: Optional[Dict[str, Any]] = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Endpoints
|
| 37 |
+
@router.post("/transcribe", response_model=BatchJobResponse)
|
| 38 |
+
async def create_batch_job(
|
| 39 |
+
background_tasks: BackgroundTasks,
|
| 40 |
+
files: List[UploadFile] = File(..., description="Audio files to transcribe"),
|
| 41 |
+
language: Optional[str] = Form(None, description="Language code (e.g., 'en', 'hi')"),
|
| 42 |
+
output_format: str = Form("txt", description="Output format (txt, srt)"),
|
| 43 |
+
):
|
| 44 |
+
"""
|
| 45 |
+
Submit a batch of audio files for transcription.
|
| 46 |
+
|
| 47 |
+
1. Uploads multiple files
|
| 48 |
+
2. Creates a batch job
|
| 49 |
+
3. Starts processing in background
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
files: List of audio files
|
| 53 |
+
language: Optional language code
|
| 54 |
+
output_format: Output format (txt or srt)
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
Created job details
|
| 58 |
+
"""
|
| 59 |
+
if not files:
|
| 60 |
+
raise HTTPException(status_code=400, detail="No files provided")
|
| 61 |
+
|
| 62 |
+
if len(files) > 50:
|
| 63 |
+
raise HTTPException(status_code=400, detail="Maximum 50 files per batch")
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
service = get_batch_service()
|
| 67 |
+
|
| 68 |
+
# Create temp files for processing
|
| 69 |
+
file_paths = {}
|
| 70 |
+
original_names = []
|
| 71 |
+
|
| 72 |
+
for file in files:
|
| 73 |
+
suffix = Path(file.filename).suffix or ".wav"
|
| 74 |
+
# Create a named temp file that persists until manually deleted
|
| 75 |
+
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
| 76 |
+
content = await file.read()
|
| 77 |
+
tmp.write(content)
|
| 78 |
+
tmp.close()
|
| 79 |
+
|
| 80 |
+
file_paths[file.filename] = tmp.name
|
| 81 |
+
original_names.append(file.filename)
|
| 82 |
+
|
| 83 |
+
# Create job
|
| 84 |
+
job = service.create_job(
|
| 85 |
+
filenames=original_names,
|
| 86 |
+
options={
|
| 87 |
+
"language": language,
|
| 88 |
+
"output_format": output_format,
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
# Connect to Celery worker for processing
|
| 93 |
+
from app.workers.tasks import process_audio_file
|
| 94 |
+
|
| 95 |
+
# NOTE: For MVP batch service, we are currently keeping the simplified background_tasks approach
|
| 96 |
+
# because the 'process_audio_file' task defined in tasks.py is for individual files,
|
| 97 |
+
# whereas 'process_job' handles the whole batch logic (zipping etc).
|
| 98 |
+
# To fully migrate, we would need to refactor batch_service to span multiple tasks.
|
| 99 |
+
#
|
| 100 |
+
# For now, let's keep the background_task for the orchestrator, and have the orchestrator
|
| 101 |
+
# call the celery tasks for individual files?
|
| 102 |
+
# Actually, `service.process_job` currently runs synchronously in a background thread.
|
| 103 |
+
# We will leave as is for 3.1 step 1, but we CAN use Celery for the individual transcriptions.
|
| 104 |
+
|
| 105 |
+
# Start processing in background (Orchestrator runs in thread, calls expensive operations)
|
| 106 |
+
background_tasks.add_task(
|
| 107 |
+
service.process_job,
|
| 108 |
+
job_id=job.job_id,
|
| 109 |
+
file_paths=file_paths,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
return job.to_dict()
|
| 113 |
+
|
| 114 |
+
except Exception as e:
|
| 115 |
+
# Cleanup any created temp files on error
|
| 116 |
+
for path in file_paths.values():
|
| 117 |
+
try:
|
| 118 |
+
os.unlink(path)
|
| 119 |
+
except:
|
| 120 |
+
pass
|
| 121 |
+
logger.error(f"Batch job creation failed: {e}")
|
| 122 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@router.get("/jobs", response_model=List[BatchJobResponse])
|
| 126 |
+
async def list_jobs(limit: int = 10):
|
| 127 |
+
"""
|
| 128 |
+
List recent batch jobs.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
limit: Max number of jobs to return
|
| 132 |
+
|
| 133 |
+
Returns:
|
| 134 |
+
List of jobs
|
| 135 |
+
"""
|
| 136 |
+
service = get_batch_service()
|
| 137 |
+
jobs = service.list_jobs(limit)
|
| 138 |
+
return [job.to_dict() for job in jobs]
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
@router.get("/{job_id}", response_model=BatchJobResponse)
|
| 142 |
+
async def get_job_status(job_id: str):
|
| 143 |
+
"""
|
| 144 |
+
Get status of a specific batch job.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
job_id: Job ID
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Job details and progress
|
| 151 |
+
"""
|
| 152 |
+
service = get_batch_service()
|
| 153 |
+
job = service.get_job(job_id)
|
| 154 |
+
|
| 155 |
+
if not job:
|
| 156 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 157 |
+
|
| 158 |
+
return job.to_dict()
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@router.get("/{job_id}/download")
|
| 162 |
+
async def download_results(job_id: str):
|
| 163 |
+
"""
|
| 164 |
+
Download batch job results as ZIP.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
job_id: Job ID
|
| 168 |
+
|
| 169 |
+
Returns:
|
| 170 |
+
ZIP file download
|
| 171 |
+
"""
|
| 172 |
+
service = get_batch_service()
|
| 173 |
+
zip_path = service.get_zip_path(job_id)
|
| 174 |
+
|
| 175 |
+
if not zip_path:
|
| 176 |
+
raise HTTPException(status_code=404, detail="Results not available (job may be processing or failed)")
|
| 177 |
+
|
| 178 |
+
return FileResponse(
|
| 179 |
+
path=zip_path,
|
| 180 |
+
filename=f"batch_{job_id}_results.zip",
|
| 181 |
+
media_type="application/zip",
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
@router.delete("/{job_id}")
|
| 186 |
+
async def delete_job(job_id: str):
|
| 187 |
+
"""
|
| 188 |
+
Delete a batch job and cleanup files.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
job_id: Job ID
|
| 192 |
+
"""
|
| 193 |
+
service = get_batch_service()
|
| 194 |
+
|
| 195 |
+
# Try to cancel first if running
|
| 196 |
+
service.cancel_job(job_id)
|
| 197 |
+
|
| 198 |
+
# Delete data
|
| 199 |
+
success = service.delete_job(job_id)
|
| 200 |
+
|
| 201 |
+
if not success:
|
| 202 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 203 |
+
|
| 204 |
+
return {"status": "deleted", "job_id": job_id}
|
backend/app/api/routes/cloning.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Voice Cloning API Routes
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
|
| 6 |
+
from fastapi.responses import FileResponse
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
import os
|
| 9 |
+
import shutil
|
| 10 |
+
import tempfile
|
| 11 |
+
import uuid
|
| 12 |
+
|
| 13 |
+
from app.services.clone_service import get_clone_service, CloneService
|
| 14 |
+
|
| 15 |
+
router = APIRouter(prefix="/clone", tags=["Voice Cloning"])
|
| 16 |
+
|
| 17 |
+
@router.post("/synthesize")
async def clone_synthesize(
    text: str = Form(..., description="Text to speak"),
    language: str = Form("en", description="Language code (en, es, fr, de, etc.)"),
    files: List[UploadFile] = File(..., description="Reference audio samples (1-3 files, 3-10s each recommended)"),
    service: CloneService = Depends(get_clone_service)
):
    """
    Clone a voice from reference audio samples.

    Uses Coqui XTTS v2.
    WARNING: Heavy operation. May take 5-20 seconds depending on GPU.

    Args:
        text: Text to synthesize in the cloned voice.
        language: XTTS language code for the output speech.
        files: One or more reference recordings of the target speaker.
        service: Injected voice-cloning backend.

    Returns:
        FileResponse streaming the generated WAV file.

    Raises:
        HTTPException: 400 when no reference files are given, 503 when the
            TTS library is not installed, 500 on synthesis failure.
    """
    # Validation
    if not files:
        raise HTTPException(status_code=400, detail="At least one reference audio file is required")

    temp_files = []
    output_path = None

    try:
        # Save reference files to disk: the cloning backend reads file paths.
        for file in files:
            suffix = os.path.splitext(file.filename)[1] or ".wav"
            tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
            content = await file.read()
            tmp.write(content)
            tmp.close()
            temp_files.append(tmp.name)

        # Generate a unique output path so concurrent requests cannot collide.
        output_filename = f"cloned_{uuid.uuid4()}.wav"
        output_path = os.path.join(tempfile.gettempdir(), output_filename)

        # Synthesize
        service.clone_voice(
            text=text,
            speaker_wav_paths=temp_files,
            language=language,
            output_path=output_path
        )

        return FileResponse(
            output_path,
            filename="cloned_speech.wav",
            media_type="audio/wav"
        )

    except ImportError:
        raise HTTPException(status_code=503, detail="Voice Cloning service not available (TTS library missing)")
    except HTTPException:
        # Preserve intended status codes instead of flattening them to 500.
        raise
    except Exception as e:
        # Remove a partially written output file so failed requests do not
        # leak files into the temp directory.
        if output_path:
            try:
                os.unlink(output_path)
            except OSError:
                pass
        raise HTTPException(status_code=500, detail=str(e))

    finally:
        # Cleanup input files. Only ignore filesystem errors (e.g. already
        # removed) — a bare except here previously hid real bugs.
        for p in temp_files:
            try:
                os.unlink(p)
            except OSError:
                pass
        # Note: Output file cleanup needs management in prod (background task or stream)
|
| 78 |
+
|
| 79 |
+
@router.get("/languages")
def get_languages(service: CloneService = Depends(get_clone_service)):
    """Return the language codes supported by the cloning backend."""
    supported = service.get_supported_languages()
    return {"languages": supported}
|
backend/app/api/routes/health.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Health Check Router
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter
|
| 6 |
+
|
| 7 |
+
router = APIRouter(prefix="/health", tags=["Health"])
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@router.get("")
@router.get("/")
async def health_check():
    """Basic liveness probe: always reports the service as healthy."""
    payload = {
        "status": "healthy",
        "service": "voiceforge-api",
        "version": "1.0.0",
    }
    return payload
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@router.get("/ready")
async def readiness_check():
    """Readiness check - verifies all dependencies are available.

    NOTE(review): the individual checks are currently hard-coded to "ok";
    real database/Redis/Google Cloud probes are still TODO.
    """
    # TODO: Check database, Redis, Google Cloud connectivity
    checks = {name: "ok" for name in ("database", "redis", "google_cloud")}
    return {"status": "ready", "checks": checks}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@router.get("/memory")
async def memory_status():
    """Get current memory usage and loaded Whisper models.

    Reports process memory (MB) plus, for every loaded model, how long it
    has been idle — useful for deciding when to call /memory/cleanup.

    Returns:
        dict with ``memory_mb``, ``loaded_models`` and per-model
        ``models_detail`` (``loaded`` flag and ``idle_seconds``).
    """
    from ...services.whisper_stt_service import (
        _whisper_models,
        _model_last_used,
        get_memory_usage_mb
    )
    import time

    current_time = time.time()
    models_info = {}

    # Iterating the dict directly yields its keys; ``.keys()`` was redundant.
    for name in _whisper_models:
        last_used = _model_last_used.get(name, 0)
        # idle_seconds is 0 when the model has never been used since loading.
        idle_seconds = current_time - last_used if last_used else 0
        models_info[name] = {
            "loaded": True,
            "idle_seconds": round(idle_seconds, 1)
        }

    return {
        "memory_mb": round(get_memory_usage_mb(), 1),
        "loaded_models": list(_whisper_models),
        "models_detail": models_info
    }
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
@router.post("/memory/cleanup")
async def cleanup_memory():
    """Unload idle Whisper models and report how much memory was freed."""
    from ...services.whisper_stt_service import cleanup_idle_models, get_memory_usage_mb

    usage_before = get_memory_usage_mb()
    cleanup_idle_models()
    usage_after = get_memory_usage_mb()

    return {
        "memory_before_mb": round(usage_before, 1),
        "memory_after_mb": round(usage_after, 1),
        "freed_mb": round(usage_before - usage_after, 1),
    }
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@router.post("/memory/unload-all")
async def unload_all():
    """Unload every loaded model to free the maximum amount of memory."""
    from ...services.whisper_stt_service import unload_all_models, get_memory_usage_mb

    usage_before = get_memory_usage_mb()
    dropped = unload_all_models()
    usage_after = get_memory_usage_mb()

    return {
        "unloaded_models": dropped,
        "memory_before_mb": round(usage_before, 1),
        "memory_after_mb": round(usage_after, 1),
        "freed_mb": round(usage_before - usage_after, 1),
    }
|
backend/app/api/routes/s2s.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Speech-to-Speech API Router
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Optional
|
| 7 |
+
from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException
|
| 8 |
+
|
| 9 |
+
from app.services.speech_bridge_service import get_bridge_service, SpeechBridgeService
|
| 10 |
+
from app.core.config import get_settings
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
router = APIRouter(prefix="/s2s", tags=["Speech-to-Speech"])
|
| 14 |
+
settings = get_settings()
|
| 15 |
+
|
| 16 |
+
@router.post("/process")
async def process_speech_to_speech(
    file: UploadFile = File(..., description="Audio file to process"),
    source_lang: str = Form("en", description="Source language code (e.g. 'en', 'hi')"),
    target_lang: str = Form("es", description="Target language code (e.g. 'es', 'fr')"),
    voice_id: Optional[str] = Form(None, description="Target TTS Voice ID"),
    bridge_service: SpeechBridgeService = Depends(get_bridge_service)
):
    """
    Process audio: Speech -> Text -> Translation -> Speech

    Args:
        file: Source-language audio to translate.
        source_lang: Language code of the input speech.
        target_lang: Language code of the output speech.
        voice_id: Optional TTS voice for the translated output.
        bridge_service: Injected STT->translate->TTS pipeline.

    Raises:
        HTTPException: 400 when the pipeline reports an error, 500 on
            unexpected failure.
    """
    try:
        # Read audio file
        audio_bytes = await file.read()

        result = await bridge_service.process_speech_to_speech(
            audio_bytes=audio_bytes,
            source_lang=source_lang,
            target_lang=target_lang,
            voice_id=voice_id
        )

        if "error" in result:
            raise HTTPException(status_code=400, detail=result["error"])

        return result

    except HTTPException:
        # BUG FIX: without this clause the 400 above was caught by the
        # generic handler below and re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"S2S API Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
backend/app/api/routes/sign.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sign Language API Routes
|
| 3 |
+
Provides WebSocket and REST endpoints for ASL recognition.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
|
| 7 |
+
from fastapi.responses import JSONResponse
|
| 8 |
+
import numpy as np
|
| 9 |
+
import base64
|
| 10 |
+
import cv2
|
| 11 |
+
import logging
|
| 12 |
+
from typing import List
|
| 13 |
+
|
| 14 |
+
from ...services.sign_recognition_service import get_sign_service, SignPrediction
|
| 15 |
+
from ...services.sign_avatar_service import get_avatar_service
|
| 16 |
+
from pydantic import BaseModel
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
router = APIRouter(prefix="/sign", tags=["Sign Language"])
|
| 21 |
+
|
| 22 |
+
class TextToSignRequest(BaseModel):
    """Request body for POST /sign/animate: text to be finger-spelled."""
    text: str
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.get("/health")
async def sign_health():
    """Check if sign recognition service is available.

    Instantiating the service is the check itself; any initialization
    failure is reported in the payload rather than raised, so monitoring
    always gets a 200 with a status field.
    """
    try:
        # The previous binding ``service = get_sign_service()`` was unused.
        get_sign_service()
        return {"status": "ready", "service": "SignRecognitionService"}
    except Exception as e:
        # Health probes deliberately report errors instead of raising.
        return {"status": "error", "message": str(e)}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@router.post("/recognize")
async def recognize_sign(file: UploadFile = File(..., description="Image of hand sign")):
    """
    Recognize ASL letter from a single image.

    Upload an image containing a hand sign to get the predicted letter.

    Raises:
        HTTPException: 400 when the upload is not a decodable image,
            500 on recognition failure.
    """
    try:
        # Read image
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        # cv2.imdecode returns None (no exception) for undecodable bytes.
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        # Get predictions
        service = get_sign_service()
        predictions = service.process_frame(image)

        if not predictions:
            return JSONResponse({
                "success": True,
                "predictions": [],
                "message": "No hands detected in image"
            })

        return JSONResponse({
            "success": True,
            "predictions": [
                {
                    "letter": p.letter,
                    "confidence": p.confidence
                }
                for p in predictions
            ]
        })

    except HTTPException:
        # BUG FIX: the generic handler below previously converted the
        # 400 "Invalid image file" into a 500.
        raise
    except Exception as e:
        logger.error(f"Sign recognition error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@router.websocket("/live")
async def sign_websocket(websocket: WebSocket):
    """
    WebSocket endpoint for real-time sign language recognition.

    Client sends base64-encoded JPEG frames, server responds with predictions.

    Protocol:
    - Client sends: {"frame": "<base64 jpeg>"}
    - Server sends: {"predictions": [{"letter": "A", "confidence": 0.8}]}

    Per-frame errors (missing field, bad base64, undecodable image) are
    reported to the client as {"error": ...} and the loop continues; only
    unexpected errors close the socket.
    """
    await websocket.accept()
    # Service is resolved once per connection and reused for every frame.
    service = get_sign_service()

    logger.info("Sign language WebSocket connected")

    try:
        while True:
            # Receive frame from client; blocks until the next message.
            data = await websocket.receive_json()

            if "frame" not in data:
                await websocket.send_json({"error": "Missing 'frame' field"})
                continue

            # Decode base64 image. Kept in its own try so one bad frame
            # does not tear down the whole connection.
            try:
                frame_data = base64.b64decode(data["frame"])
                nparr = np.frombuffer(frame_data, np.uint8)
                frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                # cv2.imdecode returns None (no exception) for bad bytes.
                if frame is None:
                    await websocket.send_json({"error": "Invalid frame data"})
                    continue

            except Exception as e:
                await websocket.send_json({"error": f"Frame decode error: {e}"})
                continue

            # Process frame
            predictions = service.process_frame(frame)

            # Send results (confidence rounded for compact payloads)
            await websocket.send_json({
                "predictions": [
                    {
                        "letter": p.letter,
                        "confidence": round(p.confidence, 2)
                    }
                    for p in predictions
                ]
            })

    except WebSocketDisconnect:
        # Normal client disconnect — not an error.
        logger.info("Sign language WebSocket disconnected")
    except Exception as e:
        # 1011 = internal server error per the WebSocket close-code spec.
        logger.error(f"WebSocket error: {e}")
        await websocket.close(code=1011, reason=str(e))
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@router.get("/alphabet")
async def get_alphabet():
    """Get list of supported ASL letters"""
    implemented = "ABCDILUVWY5"  # Currently implemented
    return {
        "supported_letters": [letter for letter in implemented],
        "note": "J and Z require motion tracking (coming soon)"
    }
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@router.post("/animate")
async def animate_text(request: TextToSignRequest):
    """
    Convert text to sign language animation sequence (Finger Spelling).
    """
    try:
        avatar = get_avatar_service()
        glosses = avatar.text_to_glosses(request.text)
        return {
            "success": True,
            "sequence": glosses,
            "count": len(glosses),
        }
    except Exception as e:
        logger.error(f"Animation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
backend/app/api/routes/sign_bridge.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sign-to-Speech API Router
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Optional, Dict, Any
|
| 7 |
+
from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, Body
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
from app.services.sign_bridge_service import get_sign_bridge_service, SignBridgeService
|
| 11 |
+
from app.core.config import get_settings
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
router = APIRouter(prefix="/sign-bridge", tags=["Sign-to-Speech"])
|
| 15 |
+
settings = get_settings()
|
| 16 |
+
|
| 17 |
+
class SignTextRequest(BaseModel):
    """Request body for POST /sign-bridge/speak."""
    text: str
    # TTS voice used when the client does not choose one.
    voice_id: Optional[str] = "en-US-AriaNeural"
|
| 20 |
+
|
| 21 |
+
@router.post("/speak")
async def speak_sign_text(
    request: SignTextRequest,
    bridge_service: SignBridgeService = Depends(get_sign_bridge_service)
):
    """
    Speak text derived from Sign Language.

    Raises:
        HTTPException: 400 when the bridge reports an error, 500 on
            unexpected failure.
    """
    try:
        result = await bridge_service.speak_text(
            text=request.text,
            voice_id=request.voice_id
        )
        if "error" in result:
            raise HTTPException(status_code=400, detail=result["error"])
        return result
    except HTTPException:
        # BUG FIX: the generic handler below previously re-raised the
        # 400 above as a 500.
        raise
    except Exception as e:
        logger.error(f"Sign-Bridge Speak Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 40 |
+
|
| 41 |
+
@router.post("/process-frame")
async def process_frame(
    file: UploadFile = File(...),
    bridge_service: SignBridgeService = Depends(get_sign_bridge_service)
):
    """
    Process a video frame for sign recognition (Backend-side).
    Note: For real-time, client-side MediaPipe is preferred.

    Raises:
        HTTPException: 400 when the upload is not a decodable image,
            500 on processing failure.
    """
    try:
        # Read image (imports kept local: cv2/numpy are heavy and only
        # needed on this path, matching the file's existing style)
        import numpy as np
        import cv2

        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        # BUG FIX: cv2.imdecode returns None for undecodable bytes; the
        # old code passed None straight into the service.
        if img is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        result = await bridge_service.process_sign_frame(img)
        return result
    except HTTPException:
        # Preserve intended status codes instead of flattening to 500.
        raise
    except Exception as e:
        logger.error(f"Sign-Bridge Frame Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
backend/app/api/routes/stt.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Speech-to-Text API Router
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
from typing import Optional, List
|
| 8 |
+
|
| 9 |
+
from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends, Request
|
| 10 |
+
from fastapi.responses import JSONResponse
|
| 11 |
+
|
| 12 |
+
from ...core.limiter import limiter
|
| 13 |
+
|
| 14 |
+
from ...services.stt_service import get_stt_service, STTService
|
| 15 |
+
from ...services.file_service import get_file_service, FileService
|
| 16 |
+
from ...schemas.stt import (
|
| 17 |
+
TranscriptionResponse,
|
| 18 |
+
TranscriptionRequest,
|
| 19 |
+
LanguageInfo,
|
| 20 |
+
LanguageListResponse,
|
| 21 |
+
)
|
| 22 |
+
from ...core.config import get_settings
|
| 23 |
+
from sqlalchemy.orm import Session
|
| 24 |
+
from app.models import get_db, AudioFile, Transcript
|
| 25 |
+
from ...workers.tasks import process_audio_file
|
| 26 |
+
from celery.result import AsyncResult
|
| 27 |
+
from ...schemas.stt import (
|
| 28 |
+
TranscriptionResponse,
|
| 29 |
+
TranscriptionRequest,
|
| 30 |
+
LanguageInfo,
|
| 31 |
+
LanguageListResponse,
|
| 32 |
+
AsyncTranscriptionResponse,
|
| 33 |
+
TaskStatusResponse,
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
router = APIRouter(prefix="/stt", tags=["Speech-to-Text"])
|
| 39 |
+
settings = get_settings()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@router.get("/languages", response_model=LanguageListResponse)
async def get_supported_languages(
    stt_service: STTService = Depends(get_stt_service),
):
    """
    Get list of supported languages for speech-to-text
    """
    supported = stt_service.get_supported_languages()
    return LanguageListResponse(languages=supported, total=len(supported))
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _persist_transcription(db, storage_path, filename, ext, metadata, language, result):
    """Persist the audio file and transcript; return the transcript ID.

    Persistence is best-effort: any database failure is logged, the
    session is rolled back, and ``None`` is returned so the caller can
    still hand the successful transcription back to the client.
    """
    from datetime import timezone

    try:
        # 1. Create AudioFile record
        audio_file = AudioFile(
            storage_path=str(storage_path),
            original_filename=filename,
            duration=result.duration,
            format=ext,
            sample_rate=metadata.get("sample_rate"),
            language=language,
            detected_language=result.language,
            status="done"
        )
        db.add(audio_file)
        db.flush()  # assigns audio_file.id without committing

        # 2. Create Transcript record
        transcript = Transcript(
            audio_file_id=audio_file.id,
            raw_text=result.text,
            processed_text=result.text,  # initially identical; edits come later
            segments=[s.model_dump() for s in result.segments] if result.segments else [],
            language=result.language,
            # timezone-aware timestamp; datetime.utcnow() is naive and deprecated
            created_at=datetime.now(timezone.utc),
        )
        db.add(transcript)
        db.commit()
        db.refresh(transcript)
        return transcript.id
    except Exception as e:
        logger.error(f"Failed to save to DB: {e}")
        # Leave the session usable for subsequent requests.
        db.rollback()
        return None


@router.post("/upload", response_model=TranscriptionResponse)
@limiter.limit("10/minute")
async def transcribe_upload(
    request: Request,
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    enable_punctuation: bool = Form(default=True, description="Enable automatic punctuation"),
    enable_word_timestamps: bool = Form(default=True, description="Include word-level timestamps"),
    enable_diarization: bool = Form(default=False, description="Enable speaker diarization"),
    speaker_count: Optional[int] = Form(default=None, description="Expected number of speakers"),
    prompt: Optional[str] = Form(None, description="Custom vocabulary/keywords (e.g. 'VoiceForge, PyTorch')"),
    stt_service: STTService = Depends(get_stt_service),
    file_service: FileService = Depends(get_file_service),
    db: Session = Depends(get_db),
):
    """
    Transcribe an uploaded audio file

    Supports: WAV, MP3, M4A, FLAC, OGG, WebM

    For files longer than 1 minute, consider using the async endpoint.

    Raises:
        HTTPException: 400 on unsupported format/language, 404 when the
            saved file cannot be found, 500 on transcription or response
            validation failure.
    """
    # Validate file type
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    ext = file.filename.split(".")[-1].lower()
    if ext not in settings.supported_audio_formats_list:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {ext}. Supported: {', '.join(settings.supported_audio_formats_list)}"
        )

    # Validate language
    if language not in settings.supported_languages_list:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {language}. Supported: {', '.join(settings.supported_languages_list)}"
        )

    try:
        # Read file content and save to storage
        content = await file.read()
        storage_path, metadata = file_service.save_upload(
            file_content=content,
            original_filename=file.filename,
        )

        logger.info(f"Processing transcription for {file.filename} ({len(content)} bytes)")

        # Perform transcription
        result = stt_service.transcribe_file(
            audio_path=storage_path,
            language=language,
            enable_automatic_punctuation=enable_punctuation,
            enable_word_time_offsets=enable_word_timestamps,
            enable_speaker_diarization=enable_diarization,
            diarization_speaker_count=speaker_count,
            sample_rate=metadata.get("sample_rate"),
            prompt=prompt,  # Custom vocabulary
        )

        # BUG FIX: persistence is now isolated in a helper so that a
        # response-validation error below is no longer swallowed by the
        # DB-save exception handler (which previously returned a bare 200).
        transcript_id = _persist_transcription(
            db, storage_path, file.filename, ext, metadata, language, result
        )

        if transcript_id is None:
            # DB save failed: still return the transcription itself.
            return result

        response_data = result.model_dump()
        response_data["id"] = transcript_id

        # Explicitly validate to catch errors early
        try:
            return TranscriptionResponse(**response_data)
        except Exception as e:
            logger.error(f"Validation error for response: {e}")
            logger.error(f"Response data: {response_data}")
            raise HTTPException(status_code=500, detail=f"Response validation failed: {str(e)}")

    except HTTPException:
        # Preserve intended status codes instead of flattening to 500.
        raise
    except FileNotFoundError as e:
        logger.error(f"File error: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except ValueError as e:
        logger.error(f"Validation error: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Transcription failed: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
@router.post("/upload/quality")
async def transcribe_quality(
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    preprocess: bool = Form(default=False, description="Apply noise reduction (5-15% WER improvement)"),
    prompt: Optional[str] = Form(None, description="Custom vocabulary/keywords"),
):
    """
    High-quality transcription mode (optimized for accuracy).

    Features:
    - beam_size=5 for more accurate decoding (~40% fewer errors)
    - condition_on_previous_text=False to reduce hallucinations
    - Optional audio preprocessing for noisy environments

    Trade-off: ~2x slower than standard mode
    Best for: Important recordings, noisy audio, reduced error tolerance
    """
    from app.services.whisper_stt_service import get_whisper_stt_service
    import tempfile
    import os

    # Validate file
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    try:
        content = await file.read()

        # Preserve the upload's extension so the decoder can sniff the
        # container correctly (previously everything was forced to .wav).
        suffix = os.path.splitext(file.filename)[1] or ".wav"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
            f.write(content)
            temp_path = f.name

        try:
            stt_service = get_whisper_stt_service()
            result = stt_service.transcribe_quality(
                temp_path,
                language=language,
                preprocess=preprocess,
                prompt=prompt,
            )
            return result
        finally:
            try:
                os.unlink(temp_path)
            except OSError:
                # Narrowed from a bare except: only ignore filesystem races.
                pass

    except Exception as e:
        logger.exception(f"Quality transcription failed: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@router.post("/upload/batch")
async def transcribe_batch(
    files: List[UploadFile] = File(..., description="Multiple audio files to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    batch_size: int = Form(default=8, description="Batch size (8 optimal for CPU)"),
):
    """
    Batch transcription for high throughput.

    Uses BatchedInferencePipeline for 2-3x speedup on concurrent files.

    Best for: Processing multiple files, API with high concurrency

    Per-file failures are reported inline in the results list so one bad
    file does not abort the whole batch.
    """
    from app.services.whisper_stt_service import get_whisper_stt_service
    import tempfile
    import os

    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    results = []
    stt_service = get_whisper_stt_service()

    for file in files:
        if not file.filename:
            # Skip nameless parts silently, matching previous behavior.
            continue

        try:
            content = await file.read()

            # Preserve the upload's extension so the decoder can sniff the
            # container correctly (previously everything was forced to .wav).
            suffix = os.path.splitext(file.filename)[1] or ".wav"
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as f:
                f.write(content)
                temp_path = f.name

            try:
                result = stt_service.transcribe_batched(
                    temp_path,
                    language=language,
                    batch_size=batch_size,
                )
                result["filename"] = file.filename
                results.append(result)
            finally:
                try:
                    os.unlink(temp_path)
                except OSError:
                    # Narrowed from a bare except: only ignore filesystem races.
                    pass

        except Exception as e:
            logger.error(f"Failed to transcribe {file.filename}: {e}")
            results.append({
                "filename": file.filename,
                "error": str(e),
            })

    return {
        "count": len(results),
        "results": results,
        "mode": "batched",
        "batch_size": batch_size,
    }
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
@router.post("/async-upload", response_model=AsyncTranscriptionResponse)
async def transcribe_async_upload(
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    file_service: FileService = Depends(get_file_service),
    db: Session = Depends(get_db),
):
    """
    Asynchronously transcribe an uploaded audio file (Celery).

    Saves the upload, creates an AudioFile row in 'queued' state,
    dispatches a Celery task, and returns the task/record ids.
    """
    # The extension check below needs a filename to work with.
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    extension = file.filename.rsplit(".", 1)[-1].lower()
    if extension not in settings.supported_audio_formats_list:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {extension}"
        )

    try:
        payload = await file.read()
        storage_path, meta = file_service.save_upload(
            file_content=payload,
            original_filename=file.filename,
        )

        # Persist the record in 'queued' state; the worker fills in duration.
        record = AudioFile(
            storage_path=str(storage_path),
            original_filename=file.filename,
            duration=0.0,  # Will be updated by worker
            format=extension,
            sample_rate=meta.get("sample_rate"),
            language=language,
            status="queued"
        )
        db.add(record)
        db.commit()
        db.refresh(record)

        # Hand off the heavy lifting to the Celery worker.
        queued_task = process_audio_file.delay(record.id)

        return AsyncTranscriptionResponse(
            task_id=queued_task.id,
            audio_file_id=record.id,
            status="queued",
            message="File uploaded and queued for processing"
        )

    except Exception as e:
        logger.exception(f"Async upload failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(task_id: str, db: Session = Depends(get_db)):
    """
    Check status of an async transcription task.
    """
    celery_result = AsyncResult(task_id)

    response = TaskStatusResponse(
        task_id=task_id,
        status=celery_result.status.lower(),
        created_at=datetime.utcnow(),  # approximate; not tracked in DB yet
        updated_at=datetime.utcnow()
    )

    if celery_result.successful():
        # The task persists its output to the DB and returns None, so there is
        # no payload to surface here -- just report completion.
        # TODO: store task_id on AudioFile/Transcript to link results back.
        response.status = "completed"
        response.progress = 100.0
    elif celery_result.failed():
        response.status = "failed"
        response.error = str(celery_result.result)
    elif celery_result.state == 'PROGRESS':
        response.status = "processing"
        # Task-side progress metadata would be read here if the task reported it.

    return response
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
@router.post("/transcribe-bytes", response_model=TranscriptionResponse)
async def transcribe_bytes(
    audio_content: bytes,
    language: str = "en-US",
    encoding: str = "LINEAR16",
    sample_rate: int = 16000,
    stt_service: STTService = Depends(get_stt_service),
):
    """
    Transcribe raw audio bytes (for streaming/real-time use)

    This endpoint is primarily for internal use or advanced clients
    that send pre-processed audio data.
    """
    try:
        return stt_service.transcribe_bytes(
            audio_content=audio_content,
            language=language,
            encoding=encoding,
            sample_rate=sample_rate,
        )
    except Exception as e:
        logger.exception(f"Transcription failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# TODO: WebSocket endpoint for real-time streaming
|
| 418 |
+
# @router.websocket("/stream")
|
| 419 |
+
# async def stream_transcription(websocket: WebSocket):
|
| 420 |
+
# """Real-time streaming transcription via WebSocket"""
|
| 421 |
+
# pass
|
| 422 |
+
|
| 423 |
+
@router.post("/upload/diarize")
async def diarize_audio(
    file: UploadFile = File(..., description="Audio file to diarize"),
    num_speakers: Optional[int] = Form(None, description="Exact number of speakers (optional)"),
    min_speakers: Optional[int] = Form(None, description="Minimum number of speakers (optional)"),
    max_speakers: Optional[int] = Form(None, description="Maximum number of speakers (optional)"),
    language: Optional[str] = Form(None, description="Language code (e.g., 'en'). Auto-detected if not provided."),
    preprocess: bool = Form(False, description="Apply noise reduction before processing (improves accuracy for noisy audio)"),
):
    """
    Perform Speaker Diarization ("Who said what").

    Uses faster-whisper for transcription + pyannote.audio for speaker identification.

    Requires:
    - HF_TOKEN in .env for Pyannote model access

    Returns:
    - segments: List of segments with timestamps, text, and speaker labels
    - speaker_stats: Speaking time per speaker
    - language: Detected/specified language
    """
    from app.services.diarization_service import get_diarization_service
    import contextlib
    import tempfile
    import os

    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    try:
        # Persist the upload so the diarization pipeline can read a file path.
        content = await file.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(content)
            temp_path = f.name

        try:
            service = get_diarization_service()
            result = service.process_audio(
                temp_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                language=language,
                preprocess=preprocess,
            )
            return result

        except ValueError as e:
            # Token missing
            raise HTTPException(status_code=400, detail=str(e))
        except ImportError as e:
            # Not installed
            raise HTTPException(status_code=503, detail=str(e))
        except Exception as e:
            logger.exception("Diarization error")
            raise HTTPException(status_code=500, detail=f"Diarization failed: {str(e)}")

        finally:
            # Best-effort temp-file cleanup; ignore only filesystem errors
            # instead of a bare except that also hid NameError etc.
            with contextlib.suppress(OSError):
                os.unlink(temp_path)

    except HTTPException:
        # BUGFIX: the outer handler previously caught these and re-wrapped the
        # specific 400/503 responses raised above as generic 500s.
        raise
    except Exception as e:
        logger.error(f"Diarization request failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
backend/app/api/routes/transcripts.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transcript Management Routes
|
| 3 |
+
CRUD operations and Export
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Optional
|
| 7 |
+
from fastapi import APIRouter, Depends, HTTPException, Response, Query, UploadFile, File, Form
|
| 8 |
+
from sqlalchemy.orm import Session
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from ...models import get_db, Transcript, AudioFile
|
| 12 |
+
from ...schemas.transcript import TranscriptResponse, TranscriptUpdate
|
| 13 |
+
from ...services.nlp_service import get_nlp_service, NLPService
|
| 14 |
+
from ...services.export_service import ExportService
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
router = APIRouter(prefix="/transcripts", tags=["Transcripts"])
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@router.get("", response_model=List[TranscriptResponse])
async def list_transcripts(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List all transcripts, newest first, with offset/limit paging."""
    query = (
        db.query(Transcript)
        .order_by(Transcript.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    return query.all()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@router.get("/{transcript_id}", response_model=TranscriptResponse)
async def get_transcript(
    transcript_id: int,
    db: Session = Depends(get_db),
):
    """Get specific transcript details"""
    record = db.query(Transcript).filter(Transcript.id == transcript_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Transcript not found")
    return record
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@router.post("/{transcript_id}/analyze")
async def analyze_transcript(
    transcript_id: int,
    db: Session = Depends(get_db),
    nlp_service: NLPService = Depends(get_nlp_service),
):
    """Run NLP analysis (sentiment, keywords, summary) on a transcript."""
    record = db.query(Transcript).filter(Transcript.id == transcript_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Transcript not found")

    if not record.processed_text:
        raise HTTPException(status_code=400, detail="Transcript has no text content")

    # Run the NLP pipeline over the stored text.
    analysis = nlp_service.process_transcript(record.processed_text)

    # Persist the analysis results on the transcript row.
    record.sentiment = analysis["sentiment"]
    record.topics = {"keywords": analysis["keywords"]}
    record.summary = analysis["summary"]
    record.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(record)

    return {
        "status": "success",
        "analysis": analysis
    }
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@router.get("/{transcript_id}/export")
async def export_transcript(
    transcript_id: int,
    format: str = Query(..., regex="^(txt|srt|vtt|pdf)$"),
    db: Session = Depends(get_db),
):
    """
    Export transcript to specific format
    """
    transcript = db.query(Transcript).filter(Transcript.id == transcript_id).first()
    if transcript is None:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # Shape the ORM row into the plain dict the export service expects.
    data = {
        "id": transcript.id,
        "text": transcript.processed_text,
        "created_at": str(transcript.created_at),
        "duration": 0,
        "segments": transcript.segments,
        "words": [],
        "sentiment": transcript.sentiment,
    }

    # format -> (exporter, media type). The Query regex already restricts the
    # value; the 400 below is defence in depth.
    exporters = {
        "txt": (ExportService.to_txt, "text/plain"),
        "srt": (ExportService.to_srt, "text/plain"),
        "vtt": (ExportService.to_vtt, "text/vtt"),
        "pdf": (ExportService.to_pdf, "application/pdf"),
    }
    if format not in exporters:
        raise HTTPException(status_code=400, detail="Unsupported format")

    exporter, media_type = exporters[format]
    content = exporter(data)

    return Response(
        content=content,
        media_type=media_type,
        headers={
            "Content-Disposition": f'attachment; filename="transcript_{transcript_id}.{format}"'
        }
    )
|
| 121 |
+
@router.post("/meeting")
async def process_meeting(
    file: UploadFile = File(..., description="Audio recording of meeting"),
    num_speakers: Optional[int] = Form(None, description="Number of speakers (hint)"),
    language: Optional[str] = Form(None, description="Language code"),
    db: Session = Depends(get_db),
):
    """
    Process a meeting recording:
    1. Diarization (Who spoke when)
    2. Transcription (What was said)
    3. NLP Analysis (Summary, Action Items, Sentiment)
    4. Save to DB
    """
    import contextlib
    import shutil
    import os
    import tempfile
    from ...services.meeting_service import get_meeting_service

    # Save upload to a temp file. BUGFIX: file.filename can be None on
    # multipart uploads, which made os.path.splitext raise TypeError.
    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        shutil.copyfileobj(file.file, tmp)
        tmp_path = tmp.name

    try:
        meeting_service = get_meeting_service()

        # Run full pipeline.
        # This can be slow (minutes) so strictly speaking should be a
        # background task, but at this MVP level we run it synchronously.
        result = meeting_service.process_meeting(
            audio_path=tmp_path,
            num_speakers=num_speakers,
            language=language
        )

        # Save to DB: AudioFile row first, then the Transcript referencing it.
        # NOTE(review): the field names here (filename/filepath/file_size)
        # differ from the AudioFile usage in stt.py (original_filename/
        # storage_path) -- confirm against the model definition.
        audio_file = AudioFile(
            filename=file.filename,
            filepath="processed_in_memory",  # temp file is deleted, no permanent path
            duration=result["metadata"]["duration_seconds"],
            file_size=0,
            format=suffix.replace(".", "")
        )
        db.add(audio_file)
        db.commit()
        db.refresh(audio_file)

        transcript = Transcript(
            audio_file_id=audio_file.id,
            raw_text=result["raw_text"],
            processed_text=result["raw_text"],
            segments=result["transcript_segments"],
            sentiment=result["sentiment"],
            topics={"keywords": result["topics"]},
            action_items=result["action_items"],
            attendees=result["metadata"]["attendees"],
            summary=result["summary"],
            language=result["metadata"]["language"],
            confidence=0.95,  # Estimated
            duration=result["metadata"]["duration_seconds"],
            created_at=datetime.utcnow()
        )
        db.add(transcript)
        db.commit()
        db.refresh(transcript)

        return result

    except HTTPException:
        # Preserve any deliberate HTTP errors instead of wrapping them in 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Best-effort cleanup; suppress only filesystem errors rather than a
        # bare except that also hid programming errors.
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)
|
backend/app/api/routes/translation.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Translation API Routes
|
| 3 |
+
Endpoints for text and audio translation services
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
from typing import Optional, List
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from app.services.translation_service import get_translation_service
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
router = APIRouter(prefix="/translation", tags=["translation"])
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Request/Response Models
|
| 18 |
+
class TranslateTextRequest(BaseModel):
    """Request model for text translation."""
    # Text to translate; server-side cap of 5000 characters.
    text: str = Field(..., min_length=1, max_length=5000, description="Text to translate")
    # Source language code, e.g. 'hi' or a regional variant like 'en-US'.
    source_lang: str = Field(..., description="Source language code (e.g., 'hi', 'en-US')")
    # Target language code, e.g. 'en', 'es'.
    target_lang: str = Field(..., description="Target language code (e.g., 'en', 'es')")
    # When True, pairs without a direct model are routed through English.
    use_pivot: bool = Field(default=True, description="Use English as pivot for unsupported pairs")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TranslateTextResponse(BaseModel):
    """Response model for text translation."""
    # Final translated text.
    translated_text: str
    # Echo of the requested source language code.
    source_lang: str
    # Echo of the requested target language code.
    target_lang: str
    # Original input text.
    source_text: str
    # Wall-clock seconds spent translating (as reported by the service).
    processing_time: float
    # Word count -- presumably of the source text; confirm against the service.
    word_count: int
    # True when English was used as an intermediate pivot language.
    pivot_used: Optional[bool] = False
    # The English intermediate text when pivoting; None for direct pairs.
    intermediate_text: Optional[str] = None
    # Identifier of the MT model used, when the service reports it.
    model_used: Optional[str] = None
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class LanguageInfo(BaseModel):
    """Language information model."""
    # Language code (e.g. 'en').
    code: str
    # Display name (presumably English; set by the translation service).
    name: str
    # Flag emoji or identifier for UI display.
    flag: str
    # Native-language display name.
    native: str
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TranslationPair(BaseModel):
    """Translation pair model."""
    # Pair identifier -- format defined by the translation service (confirm).
    code: str
    # Source-language metadata.
    source: LanguageInfo
    # Target-language metadata.
    target: LanguageInfo
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class DetectLanguageResponse(BaseModel):
    """Response model for language detection."""
    # Detected language code.
    detected_language: str
    # Detection confidence score (range set by the detector; presumably 0-1).
    confidence: float
    # Extra metadata about the detected language, when available.
    language_info: Optional[dict] = None
    # Per-language probability breakdown, when available.
    all_probabilities: Optional[List[dict]] = None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Endpoints
|
| 63 |
+
@router.get("/languages", response_model=List[LanguageInfo])
async def get_supported_languages():
    """
    Get list of all supported languages.

    Returns:
        List of supported languages with metadata
    """
    return get_translation_service().get_supported_languages()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@router.get("/pairs")
async def get_supported_pairs():
    """
    Get list of all supported translation pairs.

    Returns:
        List of supported source->target language pairs
    """
    service = get_translation_service()
    # Fetch once instead of calling the service twice per request.
    pairs = service.get_supported_pairs()
    return {
        "pairs": pairs,
        "total": len(pairs),
    }
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@router.post("/text", response_model=TranslateTextResponse)
async def translate_text(request: TranslateTextRequest):
    """
    Translate text from source to target language.

    - Uses Helsinki-NLP MarianMT models (~300MB per language pair)
    - Supports pivot translation through English for unsupported pairs
    - First request for a language pair may take longer (model loading)

    Args:
        request: Translation request with text and language codes

    Returns:
        Translated text with metadata
    """
    service = get_translation_service()

    try:
        # Pick the translation strategy: pivot-capable or direct.
        if request.use_pivot:
            translate = service.translate_with_pivot
        else:
            translate = service.translate_text

        result = translate(
            text=request.text,
            source_lang=request.source_lang,
            target_lang=request.target_lang,
        )
        return TranslateTextResponse(**result)

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Translation error: {e}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@router.post("/detect", response_model=DetectLanguageResponse)
async def detect_language(text: str = Form(..., min_length=10, description="Text to analyze")):
    """
    Detect the language of input text.

    Args:
        text: Text to analyze (minimum 10 characters for accuracy)

    Returns:
        Detected language with confidence score
    """
    detection = get_translation_service().detect_language(text)

    # The service reports failures via an "error" key rather than raising.
    if detection.get("error"):
        raise HTTPException(status_code=400, detail=detection["error"])

    return DetectLanguageResponse(**detection)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@router.get("/model-info")
async def get_model_info():
    """
    Get information about loaded translation models.

    Returns:
        Model loading status and supported pairs
    """
    return get_translation_service().get_model_info()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@router.post("/audio")
async def translate_audio(
    file: UploadFile = File(..., description="Audio file to translate"),
    source_lang: str = Form(..., description="Source language code"),
    target_lang: str = Form(..., description="Target language code"),
    generate_audio: bool = Form(default=True, description="Generate TTS output"),
):
    """
    Full audio translation pipeline: STT → Translate → TTS

    1. Transcribe audio using Whisper
    2. Translate text using MarianMT
    3. Optionally generate speech in target language

    Args:
        file: Audio file (WAV, MP3, etc.)
        source_lang: Source language code
        target_lang: Target language code
        generate_audio: Whether to generate TTS output

    Returns:
        Transcription, translation, and optional audio response
    """
    import base64
    import contextlib
    import tempfile
    import os
    from app.services.whisper_stt_service import get_whisper_stt_service
    from app.services.edge_tts_service import get_edge_tts_service

    translation_service = get_translation_service()
    stt_service = get_whisper_stt_service()
    tts_service = get_edge_tts_service()

    # Save uploaded file. BUGFIX: file.filename can be None, which made
    # os.path.splitext raise TypeError before any processing started.
    suffix = os.path.splitext(file.filename or "")[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # Step 1: Transcribe
        transcription = stt_service.transcribe_file(tmp_path, language=source_lang)
        source_text = transcription["text"]

        if not source_text.strip():
            raise HTTPException(status_code=400, detail="No speech detected in audio")

        # Step 2: Translate (falls back to an English pivot automatically)
        translation = translation_service.translate_with_pivot(
            text=source_text,
            source_lang=source_lang,
            target_lang=target_lang,
        )
        translated_text = translation["translated_text"]

        # Step 3: Generate TTS (optional)
        audio_base64 = None
        if generate_audio:
            # Map base language code to an Edge TTS voice; default to English.
            voice_map = {
                "en": "en-US-AriaNeural",
                "hi": "hi-IN-SwaraNeural",
                "es": "es-ES-ElviraNeural",
                "fr": "fr-FR-DeniseNeural",
                "de": "de-DE-KatjaNeural",
                "zh": "zh-CN-XiaoxiaoNeural",
                "ja": "ja-JP-NanamiNeural",
                "ko": "ko-KR-SunHiNeural",
                "ar": "ar-SA-ZariyahNeural",
                "ru": "ru-RU-SvetlanaNeural",
            }
            target_code = target_lang.split("-")[0].lower()
            voice = voice_map.get(target_code, "en-US-AriaNeural")

            audio_bytes = tts_service.synthesize_sync(translated_text, voice=voice)
            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")

        return {
            "source_text": source_text,
            "translated_text": translated_text,
            "source_lang": source_lang,
            "target_lang": target_lang,
            "transcription_time": transcription["processing_time"],
            "translation_time": translation["processing_time"],
            "audio_base64": audio_base64,
            "audio_format": "mp3" if audio_base64 else None,
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Audio translation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Best-effort cleanup; suppress only filesystem errors instead of a
        # bare except that also hid programming errors.
        with contextlib.suppress(OSError):
            os.unlink(tmp_path)
|
backend/app/api/routes/tts.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Text-to-Speech API Router
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import base64
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Optional
|
| 8 |
+
from fastapi import APIRouter, HTTPException, Depends, Response, Request
|
| 9 |
+
from fastapi.responses import StreamingResponse
|
| 10 |
+
from io import BytesIO
|
| 11 |
+
|
| 12 |
+
from ...core.limiter import limiter
|
| 13 |
+
|
| 14 |
+
from ...services.tts_service import get_tts_service, TTSService
|
| 15 |
+
from ...schemas.tts import (
|
| 16 |
+
SynthesisRequest,
|
| 17 |
+
SynthesisResponse,
|
| 18 |
+
VoiceInfo,
|
| 19 |
+
VoiceListResponse,
|
| 20 |
+
VoicePreviewRequest,
|
| 21 |
+
)
|
| 22 |
+
from ...core.config import get_settings
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
router = APIRouter(prefix="/tts", tags=["Text-to-Speech"])
|
| 26 |
+
settings = get_settings()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@router.get("/voices", response_model=VoiceListResponse)
async def get_voices(
    language: Optional[str] = None,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Get list of available TTS voices

    Optionally filter by language code (e.g., "en-US", "es", "fr")
    """
    voices = await tts_service.get_voices(language_code=language)
    return voices
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@router.get("/voices/{language}", response_model=VoiceListResponse)
async def get_voices_by_language(
    language: str,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Get voices for a specific language
    """
    supported = settings.supported_languages_list
    if language not in supported:
        # Allow a bare base code ("en") to match regional variants ("en-US", "en-GB").
        if not any(code.startswith(language) for code in supported):
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported language: {language}"
            )

    return await tts_service.get_voices(language_code=language)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@router.post("/synthesize", response_model=SynthesisResponse)
|
| 63 |
+
@limiter.limit("10/minute")
|
| 64 |
+
async def synthesize_speech(
|
| 65 |
+
request: Request,
|
| 66 |
+
request_body: SynthesisRequest,
|
| 67 |
+
tts_service: TTSService = Depends(get_tts_service),
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Synthesize text to speech
|
| 71 |
+
|
| 72 |
+
Returns base64-encoded audio content along with metadata.
|
| 73 |
+
Decode the audio_content field to get the audio bytes.
|
| 74 |
+
"""
|
| 75 |
+
# Validate text length
|
| 76 |
+
if len(request_body.text) > 5000:
|
| 77 |
+
raise HTTPException(
|
| 78 |
+
status_code=400,
|
| 79 |
+
detail="Text too long. Maximum 5000 characters."
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Validate language
|
| 83 |
+
lang_base = request_body.language.split("-")[0] if "-" in request_body.language else request_body.language
|
| 84 |
+
supported_bases = [l.split("-")[0] for l in settings.supported_languages_list]
|
| 85 |
+
if lang_base not in supported_bases:
|
| 86 |
+
raise HTTPException(
|
| 87 |
+
status_code=400,
|
| 88 |
+
detail=f"Unsupported language: {request_body.language}"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
result = await tts_service.synthesize(request_body)
|
| 93 |
+
return result
|
| 94 |
+
except ValueError as e:
|
| 95 |
+
logger.error(f"Synthesis validation error: {e}")
|
| 96 |
+
raise HTTPException(status_code=400, detail=str(e))
|
| 97 |
+
except Exception as e:
|
| 98 |
+
logger.exception(f"Synthesis failed: {e}")
|
| 99 |
+
raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
@router.post("/stream")
|
| 103 |
+
async def stream_speech(
|
| 104 |
+
request: SynthesisRequest,
|
| 105 |
+
tts_service: TTSService = Depends(get_tts_service),
|
| 106 |
+
):
|
| 107 |
+
"""
|
| 108 |
+
Stream text-to-speech audio
|
| 109 |
+
|
| 110 |
+
Returns a chunked audio stream (audio/mpeg) for immediate playback.
|
| 111 |
+
Best for long text to reduce latency (TTFB).
|
| 112 |
+
"""
|
| 113 |
+
try:
|
| 114 |
+
return StreamingResponse(
|
| 115 |
+
tts_service.synthesize_stream(request),
|
| 116 |
+
media_type="audio/mpeg"
|
| 117 |
+
)
|
| 118 |
+
except Exception as e:
|
| 119 |
+
logger.exception(f"Streaming synthesis failed: {e}")
|
| 120 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@router.post("/ssml")
|
| 124 |
+
async def synthesize_ssml(
|
| 125 |
+
text: str,
|
| 126 |
+
voice: str = "en-US-AriaNeural",
|
| 127 |
+
rate: str = "medium",
|
| 128 |
+
pitch: str = "medium",
|
| 129 |
+
emphasis: Optional[str] = None,
|
| 130 |
+
auto_breaks: bool = True,
|
| 131 |
+
tts_service: TTSService = Depends(get_tts_service),
|
| 132 |
+
):
|
| 133 |
+
"""
|
| 134 |
+
Synthesize speech with SSML prosody control
|
| 135 |
+
|
| 136 |
+
Supports advanced speech customization:
|
| 137 |
+
- rate: 'x-slow', 'slow', 'medium', 'fast', 'x-fast'
|
| 138 |
+
- pitch: 'x-low', 'low', 'medium', 'high', 'x-high'
|
| 139 |
+
- emphasis: 'reduced', 'moderate', 'strong'
|
| 140 |
+
- auto_breaks: Add natural pauses at punctuation
|
| 141 |
+
|
| 142 |
+
Returns audio/mpeg stream.
|
| 143 |
+
"""
|
| 144 |
+
try:
|
| 145 |
+
from ...services.edge_tts_service import get_edge_tts_service
|
| 146 |
+
edge_service = get_edge_tts_service()
|
| 147 |
+
|
| 148 |
+
# Build SSML
|
| 149 |
+
ssml = edge_service.build_ssml(
|
| 150 |
+
text=text,
|
| 151 |
+
voice=voice,
|
| 152 |
+
rate=rate,
|
| 153 |
+
pitch=pitch,
|
| 154 |
+
emphasis=emphasis,
|
| 155 |
+
breaks=auto_breaks
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
# Synthesize
|
| 159 |
+
audio_bytes = await edge_service.synthesize_ssml(ssml, voice)
|
| 160 |
+
|
| 161 |
+
return Response(
|
| 162 |
+
content=audio_bytes,
|
| 163 |
+
media_type="audio/mpeg",
|
| 164 |
+
headers={"Content-Disposition": "inline; filename=speech.mp3"}
|
| 165 |
+
)
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.exception(f"SSML synthesis failed: {e}")
|
| 168 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@router.post("/synthesize/audio")
|
| 172 |
+
async def synthesize_audio_file(
|
| 173 |
+
request: SynthesisRequest,
|
| 174 |
+
tts_service: TTSService = Depends(get_tts_service),
|
| 175 |
+
):
|
| 176 |
+
"""
|
| 177 |
+
Synthesize text and return audio file directly
|
| 178 |
+
|
| 179 |
+
Returns the audio file as a downloadable stream.
|
| 180 |
+
"""
|
| 181 |
+
try:
|
| 182 |
+
result = await tts_service.synthesize(request)
|
| 183 |
+
|
| 184 |
+
# Decode base64 audio
|
| 185 |
+
audio_bytes = base64.b64decode(result.audio_content)
|
| 186 |
+
|
| 187 |
+
# Determine content type
|
| 188 |
+
content_types = {
|
| 189 |
+
"MP3": "audio/mpeg",
|
| 190 |
+
"LINEAR16": "audio/wav",
|
| 191 |
+
"OGG_OPUS": "audio/ogg",
|
| 192 |
+
}
|
| 193 |
+
content_type = content_types.get(result.encoding, "audio/mpeg")
|
| 194 |
+
|
| 195 |
+
# Return as streaming response
|
| 196 |
+
return StreamingResponse(
|
| 197 |
+
BytesIO(audio_bytes),
|
| 198 |
+
media_type=content_type,
|
| 199 |
+
headers={
|
| 200 |
+
"Content-Disposition": f'attachment; filename="speech.{result.encoding.lower()}"',
|
| 201 |
+
"Content-Length": str(result.audio_size),
|
| 202 |
+
}
|
| 203 |
+
)
|
| 204 |
+
except Exception as e:
|
| 205 |
+
logger.exception(f"Audio synthesis failed: {e}")
|
| 206 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@router.post("/preview")
|
| 210 |
+
async def preview_voice(
|
| 211 |
+
request: VoicePreviewRequest,
|
| 212 |
+
tts_service: TTSService = Depends(get_tts_service),
|
| 213 |
+
):
|
| 214 |
+
"""
|
| 215 |
+
Generate a short preview of a voice
|
| 216 |
+
|
| 217 |
+
Returns a small audio sample for voice selection UI.
|
| 218 |
+
"""
|
| 219 |
+
# Find the voice to get its language
|
| 220 |
+
voices = tts_service.get_voices().voices
|
| 221 |
+
voice_info = next((v for v in voices if v.name == request.voice), None)
|
| 222 |
+
|
| 223 |
+
if not voice_info:
|
| 224 |
+
raise HTTPException(status_code=404, detail=f"Voice not found: {request.voice}")
|
| 225 |
+
|
| 226 |
+
# Create synthesis request with preview text
|
| 227 |
+
synth_request = SynthesisRequest(
|
| 228 |
+
text=request.text or "Hello! This is a preview of my voice.",
|
| 229 |
+
language=voice_info.language_code,
|
| 230 |
+
voice=request.voice,
|
| 231 |
+
audio_encoding="MP3",
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
try:
|
| 235 |
+
result = tts_service.synthesize(synth_request)
|
| 236 |
+
|
| 237 |
+
# Return audio directly
|
| 238 |
+
audio_bytes = base64.b64decode(result.audio_content)
|
| 239 |
+
return StreamingResponse(
|
| 240 |
+
BytesIO(audio_bytes),
|
| 241 |
+
media_type="audio/mpeg",
|
| 242 |
+
)
|
| 243 |
+
except Exception as e:
|
| 244 |
+
logger.exception(f"Preview failed: {e}")
|
| 245 |
+
raise HTTPException(status_code=500, detail=str(e))
|
backend/app/api/routes/ws.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WebSocket Router for Real-Time Transcription
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import json
|
| 7 |
+
from typing import Dict, Optional
|
| 8 |
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
| 9 |
+
|
| 10 |
+
from app.core.ws_security import (
|
| 11 |
+
validate_ws_origin,
|
| 12 |
+
authenticate_websocket,
|
| 13 |
+
ws_rate_limiter
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
router = APIRouter(prefix="/ws", tags=["WebSocket"])
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ConnectionManager:
    """Tracks active WebSocket connections and their (optional) users."""

    def __init__(self):
        # client_id -> live WebSocket
        self.active_connections: Dict[str, WebSocket] = {}
        # client_id -> authenticated user id (None for anonymous)
        self.user_ids: Dict[str, Optional[int]] = {}

    async def connect(
        self,
        client_id: str,
        websocket: WebSocket,
        user_id: Optional[int] = None
    ) -> bool:
        """
        Validate and accept a client connection.

        Returns True when the connection was accepted, False when it was
        rejected because of an invalid origin.
        """
        if not validate_ws_origin(websocket):
            logger.warning(f"WebSocket rejected for {client_id}: invalid origin")
            await websocket.close(code=1008)  # 1008 = Policy Violation
            return False

        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.user_ids[client_id] = user_id
        logger.info(f"Client {client_id} connected (user_id={user_id})")
        return True

    def disconnect(self, client_id: str):
        """Forget a client and release its rate-limiter state."""
        if client_id in self.active_connections:
            del self.active_connections[client_id]
            self.user_ids.pop(client_id, None)
            ws_rate_limiter.cleanup(client_id)
            logger.info(f"Client {client_id} disconnected")

    async def send_json(self, client_id: str, data: dict):
        """Send a JSON payload to a client if it is still connected."""
        connection = self.active_connections.get(client_id)
        if connection is not None:
            await connection.send_json(data)
|
| 60 |
+
|
| 61 |
+
manager = ConnectionManager()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@router.websocket("/transcription/{client_id}")
|
| 65 |
+
async def websocket_transcription(
|
| 66 |
+
websocket: WebSocket,
|
| 67 |
+
client_id: str,
|
| 68 |
+
token: Optional[str] = Query(None)
|
| 69 |
+
):
|
| 70 |
+
"""
|
| 71 |
+
Real-time streaming transcription via WebSocket with VAD.
|
| 72 |
+
|
| 73 |
+
Optional auth via query param: ws://host/ws/transcription/{id}?token=jwt_token
|
| 74 |
+
"""
|
| 75 |
+
# Authenticate (optional for demo, but logged)
|
| 76 |
+
user_id = await authenticate_websocket(websocket, token)
|
| 77 |
+
|
| 78 |
+
if not await manager.connect(client_id, websocket, user_id):
|
| 79 |
+
return # Connection rejected
|
| 80 |
+
|
| 81 |
+
from app.services.ws_stt_service import StreamManager, transcribe_buffer
|
| 82 |
+
|
| 83 |
+
stream_manager = StreamManager(websocket)
|
| 84 |
+
|
| 85 |
+
async def handle_transcription(audio_bytes: bytes):
|
| 86 |
+
"""Callback for processing speech segments."""
|
| 87 |
+
# Check rate limit
|
| 88 |
+
if not ws_rate_limiter.check_rate(client_id):
|
| 89 |
+
await manager.send_json(client_id, {"error": "Rate limit exceeded"})
|
| 90 |
+
return
|
| 91 |
+
|
| 92 |
+
# Check message size
|
| 93 |
+
if not ws_rate_limiter.check_size(audio_bytes):
|
| 94 |
+
await manager.send_json(client_id, {"error": "Message too large"})
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
try:
|
| 98 |
+
# Send processing status
|
| 99 |
+
await manager.send_json(client_id, {"status": "processing"})
|
| 100 |
+
|
| 101 |
+
# Transcribe
|
| 102 |
+
result = await transcribe_buffer(audio_bytes)
|
| 103 |
+
text = result.get("text", "").strip()
|
| 104 |
+
|
| 105 |
+
if text:
|
| 106 |
+
# Send result
|
| 107 |
+
await manager.send_json(client_id, {
|
| 108 |
+
"text": text,
|
| 109 |
+
"is_final": True,
|
| 110 |
+
"status": "complete"
|
| 111 |
+
})
|
| 112 |
+
logger.info(f"Transcribed: {text}")
|
| 113 |
+
except Exception as e:
|
| 114 |
+
logger.error(f"Transcription callback error: {e}")
|
| 115 |
+
await manager.send_json(client_id, {"error": str(e)})
|
| 116 |
+
|
| 117 |
+
try:
|
| 118 |
+
# Start processing loop
|
| 119 |
+
await stream_manager.process_stream(handle_transcription)
|
| 120 |
+
|
| 121 |
+
except WebSocketDisconnect:
|
| 122 |
+
manager.disconnect(client_id)
|
| 123 |
+
except Exception as e:
|
| 124 |
+
logger.error(f"WebSocket error: {e}")
|
| 125 |
+
try:
|
| 126 |
+
await manager.send_json(client_id, {"error": str(e)})
|
| 127 |
+
except:
|
| 128 |
+
pass
|
| 129 |
+
manager.disconnect(client_id)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
@router.websocket("/tts/{client_id}")
|
| 133 |
+
async def websocket_tts(
|
| 134 |
+
websocket: WebSocket,
|
| 135 |
+
client_id: str,
|
| 136 |
+
token: Optional[str] = Query(None)
|
| 137 |
+
):
|
| 138 |
+
"""
|
| 139 |
+
Real-time Text-to-Speech via WebSocket
|
| 140 |
+
|
| 141 |
+
Protocol:
|
| 142 |
+
- Client sends: JSON {"text": "...", "voice": "...", "rate": "...", "pitch": "..."}
|
| 143 |
+
- Server sends: Binary audio chunks (MP3) followed by JSON {"status": "complete"}
|
| 144 |
+
|
| 145 |
+
Optional auth via query param: ws://host/ws/tts/{id}?token=jwt_token
|
| 146 |
+
This achieves <500ms TTFB by streaming as chunks are generated.
|
| 147 |
+
"""
|
| 148 |
+
# Authenticate (optional for demo, but logged)
|
| 149 |
+
user_id = await authenticate_websocket(websocket, token)
|
| 150 |
+
|
| 151 |
+
if not await manager.connect(client_id, websocket, user_id):
|
| 152 |
+
return # Connection rejected
|
| 153 |
+
|
| 154 |
+
try:
|
| 155 |
+
import edge_tts
|
| 156 |
+
|
| 157 |
+
while True:
|
| 158 |
+
# Receive synthesis request
|
| 159 |
+
data = await websocket.receive_json()
|
| 160 |
+
|
| 161 |
+
text = data.get("text", "")
|
| 162 |
+
voice = data.get("voice", "en-US-AriaNeural")
|
| 163 |
+
rate = data.get("rate", "+0%")
|
| 164 |
+
pitch = data.get("pitch", "+0Hz")
|
| 165 |
+
|
| 166 |
+
if not text:
|
| 167 |
+
await websocket.send_json({"error": "No text provided"})
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
logger.info(f"WebSocket TTS: Synthesizing '{text[:50]}...' with {voice}")
|
| 171 |
+
|
| 172 |
+
# Stream audio chunks directly
|
| 173 |
+
import time
|
| 174 |
+
start_time = time.time()
|
| 175 |
+
first_chunk_sent = False
|
| 176 |
+
total_bytes = 0
|
| 177 |
+
|
| 178 |
+
communicate = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)
|
| 179 |
+
|
| 180 |
+
async for chunk in communicate.stream():
|
| 181 |
+
if chunk["type"] == "audio":
|
| 182 |
+
await websocket.send_bytes(chunk["data"])
|
| 183 |
+
total_bytes += len(chunk["data"])
|
| 184 |
+
|
| 185 |
+
if not first_chunk_sent:
|
| 186 |
+
ttfb = (time.time() - start_time) * 1000
|
| 187 |
+
logger.info(f"WebSocket TTS TTFB: {ttfb:.0f}ms")
|
| 188 |
+
first_chunk_sent = True
|
| 189 |
+
|
| 190 |
+
# Send completion marker
|
| 191 |
+
total_time = time.time() - start_time
|
| 192 |
+
await websocket.send_json({
|
| 193 |
+
"status": "complete",
|
| 194 |
+
"total_bytes": total_bytes,
|
| 195 |
+
"total_time_ms": round(total_time * 1000),
|
| 196 |
+
"ttfb_ms": round(ttfb) if first_chunk_sent else None
|
| 197 |
+
})
|
| 198 |
+
|
| 199 |
+
except WebSocketDisconnect:
|
| 200 |
+
manager.disconnect(client_id)
|
| 201 |
+
except Exception as e:
|
| 202 |
+
logger.error(f"WebSocket TTS error: {e}")
|
| 203 |
+
try:
|
| 204 |
+
await websocket.send_json({"error": str(e)})
|
| 205 |
+
except:
|
| 206 |
+
pass
|
| 207 |
+
manager.disconnect(client_id)
|
| 208 |
+
|
backend/app/core/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge Core Package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .config import get_settings, Settings, LANGUAGE_METADATA
|
| 6 |
+
|
| 7 |
+
__all__ = ["get_settings", "Settings", "LANGUAGE_METADATA"]
|
backend/app/core/config.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge Configuration
|
| 3 |
+
Pydantic Settings for application configuration
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
from typing import List
|
| 8 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 9 |
+
from pydantic import Field
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Settings(BaseSettings):
    """Application settings loaded from environment variables"""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="allow",  # Allow extra env vars without error
    )

    # Application identity
    app_name: str = "VoiceForge"
    app_version: str = "1.0.0"
    debug: bool = False

    # API server bind address and port
    api_host: str = "0.0.0.0"
    api_port: int = 8000

    # Database
    database_url: str = Field(
        default="sqlite:///./voiceforge.db",
        description="Database connection URL (SQLite for dev, PostgreSQL for prod)"
    )

    # Redis
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL for caching and Celery"
    )

    # Google Cloud
    google_application_credentials: str = Field(
        default="./credentials/google-cloud-key.json",
        description="Path to Google Cloud service account JSON key"
    )

    # AI Services Configuration
    use_local_services: bool = Field(
        default=True,
        description="Use local free services (Whisper + EdgeTTS) instead of Google Cloud"
    )
    whisper_model: str = Field(
        default="small",
        description="Whisper model size (tiny, base, small, medium, large-v3)"
    )

    # Security
    # NOTE(review): the default secret_key must be overridden via env in
    # production — JWTs signed with this default are forgeable.
    secret_key: str = Field(
        default="your-super-secret-key-change-in-production",
        description="Secret key for JWT encoding"
    )
    access_token_expire_minutes: int = 30
    algorithm: str = "HS256"
    hf_token: str | None = Field(default=None, description="Hugging Face Token for Diarization")

    # File Storage
    upload_dir: str = "./uploads"
    max_audio_duration_seconds: int = 600  # 10 minutes
    max_upload_size_mb: int = 50

    # Supported languages, comma-separated (parsed by the property below)
    supported_languages: str = "en-US,en-GB,es-ES,es-MX,fr-FR,de-DE,ja-JP,ko-KR,zh-CN,hi-IN"

    # Audio formats, comma-separated (parsed by the property below)
    supported_audio_formats: str = "wav,mp3,m4a,flac,ogg,webm"

    @property
    def supported_languages_list(self) -> List[str]:
        """Get supported languages as a list"""
        return [lang.strip() for lang in self.supported_languages.split(",")]

    @property
    def supported_audio_formats_list(self) -> List[str]:
        """Get supported audio formats as a list"""
        return [fmt.strip() for fmt in self.supported_audio_formats.split(",")]
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Language metadata for UI display: locale code -> display name, flag
# emoji, and the language's own (native) name. Keys should track
# Settings.supported_languages.
LANGUAGE_METADATA = {
    "en-US": {"name": "English (US)", "flag": "🇺🇸", "native": "English"},
    "en-GB": {"name": "English (UK)", "flag": "🇬🇧", "native": "English"},
    "es-ES": {"name": "Spanish (Spain)", "flag": "🇪🇸", "native": "Español"},
    "es-MX": {"name": "Spanish (Mexico)", "flag": "🇲🇽", "native": "Español"},
    "fr-FR": {"name": "French", "flag": "🇫🇷", "native": "Français"},
    "de-DE": {"name": "German", "flag": "🇩🇪", "native": "Deutsch"},
    "ja-JP": {"name": "Japanese", "flag": "🇯🇵", "native": "日本語"},
    "ko-KR": {"name": "Korean", "flag": "🇰🇷", "native": "한국어"},
    "zh-CN": {"name": "Chinese (Mandarin)", "flag": "🇨🇳", "native": "中文"},
    "hi-IN": {"name": "Hindi", "flag": "🇮🇳", "native": "हिन्दी"},
}
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@lru_cache
def get_settings() -> Settings:
    """Return the process-wide Settings instance (built once, then cached)."""
    settings = Settings()
    return settings
|
backend/app/core/limiter.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded

# Initialize Limiter.
# Storage backend is chosen at import time: Redis when REDIS_URL is set
# and reachable, otherwise an in-process memory store (local dev).
redis_url = os.getenv("REDIS_URL")

# For local testing without Redis, use memory storage
if redis_url and redis_url.strip():
    try:
        import redis

        r = redis.from_url(redis_url)
        r.ping()  # Test connection; any failure falls through to memory://
        storage_uri = redis_url
    except Exception:
        # Redis not available, fall back to memory
        storage_uri = "memory://"
else:
    storage_uri = "memory://"

# NOTE(review): RateLimitExceeded is imported but unused in this module —
# presumably re-exported for exception handlers elsewhere; confirm before
# removing.
limiter = Limiter(
    key_func=get_remote_address,
    storage_uri=storage_uri,
    default_limits=["60/minute"]  # Global limit: 60 req/min per IP
)
|
backend/app/core/middleware.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Rate Limiting Middleware
|
| 3 |
+
Uses Redis to track and limit request rates per IP address.
|
| 4 |
+
Pure ASGI implementation to avoid BaseHTTPMiddleware issues.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import time
|
| 8 |
+
import redis
|
| 9 |
+
from starlette.responses import JSONResponse
|
| 10 |
+
from starlette.types import ASGIApp, Scope, Receive, Send
|
| 11 |
+
from ..core.config import get_settings
|
| 12 |
+
|
| 13 |
+
settings = get_settings()
|
| 14 |
+
|
| 15 |
+
class RateLimitMiddleware:
    """
    Pure-ASGI fixed-window rate limiter: 60 requests/minute per client IP,
    applied only to paths under /api/.

    Fails open when Redis is unreachable (at startup or per-request) so
    rate limiting can never take the API down.
    """

    def __init__(self, app: ASGIApp):
        self.app = app
        # Hardcoded or from settings (bypassing constructor arg issue)
        self.requests_per_minute = 60
        self.window_size = 60  # seconds

        # Connect to Redis; on failure the limiter is disabled entirely.
        try:
            self.redis_client = redis.from_url(settings.redis_url)
        except Exception as e:
            print(f"⚠️ Rate limiter disabled: Could not connect to Redis ({e})")
            self.redis_client = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        # Skip if not HTTP (e.g. websocket/lifespan scopes pass through).
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Skip rate limiting for non-API routes or if Redis is down
        path = scope.get("path", "")
        if not path.startswith("/api/") or self.redis_client is None:
            await self.app(scope, receive, send)
            return

        # Get client IP. NOTE(review): this is the direct peer address —
        # behind a proxy every request shares the proxy's IP; confirm
        # whether X-Forwarded-For handling is needed in deployment.
        client = scope.get("client")
        client_ip = client[0] if client else "unknown"
        key = f"rate_limit:{client_ip}"

        try:
            # Simple fixed window counter.
            # NOTE(review): these Redis calls are synchronous and block the
            # event loop; consider redis.asyncio if latency matters.
            current_count = self.redis_client.incr(key)

            # Set expiry on first request so the window resets itself.
            if current_count == 1:
                self.redis_client.expire(key, self.window_size)

            if current_count > self.requests_per_minute:
                response = JSONResponse(
                    status_code=429,
                    content={
                        "detail": "Too many requests",
                        "retry_after": self.window_size
                    },
                    headers={"Retry-After": str(self.window_size)}
                )
                await response(scope, receive, send)
                return

        except redis.RedisError:
            # Fail open if Redis has issues during request
            pass

        await self.app(scope, receive, send)
|
backend/app/core/request_size_middleware.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Request Size Limiting Middleware
|
| 3 |
+
Prevents large request bodies from consuming excessive memory.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 9 |
+
from starlette.requests import Request
|
| 10 |
+
from starlette.responses import Response
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# Default max body size: 50MB (configurable via env)
|
| 15 |
+
DEFAULT_MAX_BODY_SIZE = 50 * 1024 * 1024 # 50 MB
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """
    ASGI Middleware to limit request body size.

    Checks Content-Length header and rejects requests that exceed the limit
    with a 413 Payload Too Large response.

    Note: This checks Content-Length before reading the body.
    For chunked transfer encoding, consider additional safeguards.
    """

    def __init__(self, app, max_body_size: int | None = None):
        super().__init__(app)
        # Precedence: explicit arg > MAX_REQUEST_BODY_SIZE env var > 50 MB.
        self.max_body_size = max_body_size or int(
            os.getenv("MAX_REQUEST_BODY_SIZE", DEFAULT_MAX_BODY_SIZE)
        )
        logger.info(f"Request size limit: {self.max_body_size / 1024 / 1024:.1f} MB")

    async def dispatch(self, request: Request, call_next):
        # Skip for WebSocket upgrades — they have no bounded body.
        if request.headers.get("upgrade", "").lower() == "websocket":
            return await call_next(request)

        # Check Content-Length header (absent for chunked transfers).
        content_length = request.headers.get("content-length")

        if content_length:
            try:
                size = int(content_length)
                if size > self.max_body_size:
                    logger.warning(
                        f"Request too large: {size / 1024 / 1024:.1f} MB "
                        f"(limit: {self.max_body_size / 1024 / 1024:.1f} MB) "
                        f"from {request.client.host if request.client else 'unknown'}"
                    )
                    return Response(
                        content="Request body too large",
                        status_code=413,
                        media_type="text/plain"
                    )
            except ValueError:
                pass  # Invalid Content-Length, let the server handle it

        return await call_next(request)
|
| 63 |
+
|
| 64 |
+
class StreamingSizeValidator:
    """
    Running byte-count guard for streamed uploads.

    Use with SpooledTemporaryFile for memory-efficient large file handling:

        validator = StreamingSizeValidator(max_size=100 * 1024 * 1024)
        with SpooledTemporaryFile(max_size=5*1024*1024) as tmp:
            async for chunk in file.stream():
                validator.add(len(chunk))  # Raises if exceeded
                tmp.write(chunk)
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.current_size = 0

    def add(self, chunk_size: int):
        """Accumulate chunk_size; raise ValueError once the limit is passed."""
        self.current_size += chunk_size
        exceeded = self.current_size > self.max_size
        if exceeded:
            raise ValueError(
                f"Upload exceeds size limit: {self.current_size} > {self.max_size}"
            )

    @property
    def size(self) -> int:
        """Total number of bytes recorded so far."""
        return self.current_size
|
backend/app/core/security.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Security Utilities
|
| 3 |
+
Handles password hashing, JWT generation, and API key verification.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from datetime import datetime, timedelta
|
| 7 |
+
from typing import Optional, Union, Any
|
| 8 |
+
from jose import jwt
|
| 9 |
+
from passlib.context import CryptContext
|
| 10 |
+
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
|
| 11 |
+
from fastapi import Depends, HTTPException, status
|
| 12 |
+
from sqlalchemy.orm import Session
|
| 13 |
+
|
| 14 |
+
from ..core.config import get_settings
|
| 15 |
+
from ..models import get_db, User, ApiKey
|
| 16 |
+
|
| 17 |
+
settings = get_settings()
|
| 18 |
+
|
| 19 |
+
# Password hashing (PBKDF2 is safer/easier on Windows than bcrypt sometimes)
|
| 20 |
+
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
| 21 |
+
|
| 22 |
+
# JWT configuration
|
| 23 |
+
SECRET_KEY = settings.secret_key
|
| 24 |
+
ALGORITHM = settings.algorithm
|
| 25 |
+
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
|
| 26 |
+
|
| 27 |
+
# OAuth2 scheme - auto_error=False allows API key fallback
|
| 28 |
+
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False)
|
| 29 |
+
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when plain_password matches the stored hash."""
    return pwd_context.verify(plain_password, hashed_password)
|
| 34 |
+
|
| 35 |
+
def get_password_hash(password: str) -> str:
|
| 36 |
+
return pwd_context.hash(password)
|
| 37 |
+
|
| 38 |
+
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """Create a signed JWT access token.

    Args:
        subject: Identifier placed in the ``sub`` claim (coerced to ``str``).
        expires_delta: Optional custom token lifetime. Defaults to
            ``ACCESS_TOKEN_EXPIRE_MINUTES`` when not provided.

    Returns:
        The encoded JWT string.
    """
    # Fix: compare against None instead of truthiness so an explicit
    # timedelta(0) (immediately-expiring token) is honored rather than
    # silently replaced by the default lifetime. Annotation also corrected
    # to Optional[timedelta] to match the None default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    expire = datetime.utcnow() + expires_delta

    to_encode = {"exp": expire, "sub": str(subject)}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
| 47 |
+
|
| 48 |
+
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
) -> Optional[User]:
    """Validate JWT and return the matching user.

    Returns None (instead of raising) when the token is missing or invalid,
    which allows the API-key fallback in get_api_user_or_jwt_user.
    """
    if not token:
        return None  # Allow API key fallback

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        user_id_raw = payload.get("sub")
        if user_id_raw is None:
            return None
        # Fix: convert inside the try block. Previously int() ran outside it,
        # so a validly-signed token with a non-numeric "sub" raised ValueError
        # and surfaced as a 500 instead of an auth failure.
        user_id = int(user_id_raw)
    except Exception:
        return None

    return db.query(User).filter(User.id == user_id).first()
|
| 66 |
+
|
| 67 |
+
async def get_current_active_user(current_user: Optional[User] = Depends(get_current_user)) -> User:
    """Get current active user. Raises 401 if not authenticated."""
    if current_user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Authenticated but disabled accounts are rejected with 400.
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
|
| 78 |
+
|
| 79 |
+
async def verify_api_key(
    api_key: str = Depends(api_key_header),
    db: Session = Depends(get_db)
) -> Optional[User]:
    """
    Validate API key from the X-API-Key header.

    Returns the associated user if the key exists and is active, else None
    so that JWT authentication can be attempted instead.

    NOTE(review): keys appear to be stored and compared in plaintext;
    consider hashing stored keys and using a constant-time comparison.
    """
    if not api_key:
        return None  # No header present - fall through to JWT auth

    # `.is_(True)` is the SQLAlchemy-idiomatic boolean test (emits IS true)
    # and avoids the E712-style `== True` comparison.
    key_record = db.query(ApiKey).filter(
        ApiKey.key == api_key, ApiKey.is_active.is_(True)
    ).first()

    if key_record:
        # Update usage stats for auditing
        key_record.last_used_at = datetime.utcnow()
        db.commit()
        return key_record.user

    return None  # Invalid key
|
| 99 |
+
|
| 100 |
+
def get_api_user_or_jwt_user(
    api_key_user: Optional[User] = Depends(verify_api_key),
    jwt_user: Optional[User] = Depends(get_current_user)
) -> User:
    """Allow access via either API Key or JWT"""
    # API key wins when both credentials are supplied.
    user = api_key_user or jwt_user
    if user is not None:
        return user

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated"
    )
|
backend/app/core/security_encryption.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Field-level Encryption for SQLAlchemy Models.
|
| 3 |
+
|
| 4 |
+
Uses Fernet symmetric encryption from the `cryptography` library.
|
| 5 |
+
The ENCRYPTION_KEY should be a 32-byte base64-encoded key.
|
| 6 |
+
Generate one with: from cryptography.fernet import Fernet; print(Fernet.generate_key())
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import base64
|
| 11 |
+
import logging
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
from cryptography.fernet import Fernet, InvalidToken
|
| 15 |
+
from sqlalchemy import TypeDecorator, String
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
# --- Configuration ---
|
| 20 |
+
# IMPORTANT: Store this securely! In production, use secrets manager or env vars.
|
| 21 |
+
# Default key is for development ONLY - regenerate for production!
|
| 22 |
+
_DEFAULT_DEV_KEY = "VOICEFORGE_DEV_KEY_REPLACE_ME_NOW=" # Placeholder - NOT a valid key
|
| 23 |
+
|
| 24 |
+
def _get_encryption_key() -> bytes:
|
| 25 |
+
"""Get the encryption key from environment. Fail-closed in production."""
|
| 26 |
+
key_str = os.getenv("ENCRYPTION_KEY")
|
| 27 |
+
|
| 28 |
+
if key_str:
|
| 29 |
+
return key_str.encode()
|
| 30 |
+
|
| 31 |
+
# Check if running in production
|
| 32 |
+
is_production = os.getenv("ENVIRONMENT", "development").lower() == "production"
|
| 33 |
+
if is_production:
|
| 34 |
+
raise RuntimeError(
|
| 35 |
+
"ENCRYPTION_KEY environment variable must be set in production! "
|
| 36 |
+
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
# Development fallback - deterministic but INSECURE
|
| 40 |
+
logger.warning("⚠️ ENCRYPTION_KEY not set! Using DEV-ONLY fixed key. DO NOT USE IN PRODUCTION.")
|
| 41 |
+
# Fixed base64 key for dev consistency (32 bytes -> valid Fernet key)
|
| 42 |
+
return base64.urlsafe_b64encode(b"voiceforge_dev_key_32bytes_ok!")
|
| 43 |
+
|
| 44 |
+
# Cache the Fernet instance
|
| 45 |
+
_fernet: Optional[Fernet] = None
|
| 46 |
+
|
| 47 |
+
def get_fernet() -> Fernet:
    """Get or create the module-level cached Fernet encryption instance."""
    global _fernet
    if _fernet is None:
        # Lazily construct so the key lookup (and any production
        # fail-closed error) happens on first use, not at import time.
        _fernet = Fernet(_get_encryption_key())
    return _fernet
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# --- SQLAlchemy TypeDecorator ---
|
| 57 |
+
|
| 58 |
+
class EncryptedString(TypeDecorator):
    """
    SQLAlchemy type that encrypts/decrypts string values transparently.

    Usage:
        class User(Base):
            full_name = Column(EncryptedString(255), nullable=True)

    The encrypted data is stored as a base64-encoded string in the database.
    """
    impl = String
    cache_ok = True

    def __init__(self, length: int = 512, *args, **kwargs):
        # Ciphertext is longer than the plaintext, so reserve double the room.
        super().__init__(length * 2, *args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Encrypt the value before storing in DB."""
        if value is None:
            return None

        try:
            token = get_fernet().encrypt(value.encode('utf-8'))
            return token.decode('utf-8')
        except Exception as e:
            logger.error(f"Encryption failed: {e}")
            # Fail-open for development: store plaintext on failure.
            # NOTE(review): in production, raising here may be preferable.
            return value

    def process_result_value(self, value, dialect):
        """Decrypt the value when reading from DB."""
        if value is None:
            return None

        try:
            plaintext = get_fernet().decrypt(value.encode('utf-8'))
            return plaintext.decode('utf-8')
        except InvalidToken:
            # Likely pre-encryption legacy data; hand it back unchanged.
            logger.warning("Decryption failed - returning raw value (possible legacy data)")
            return value
        except Exception as e:
            logger.error(f"Decryption failed: {e}")
            return value
|
backend/app/core/security_headers.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from starlette.middleware.base import BaseHTTPMiddleware
|
| 2 |
+
from starlette.types import ASGIApp, Receive, Scope, Send
|
| 3 |
+
|
| 4 |
+
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach standard security headers to every HTTP response.

    Registered before CORS so the headers are present even on error
    responses and CORS-rejected requests.
    """

    # Improvement: the original defined an __init__ that only called
    # super().__init__(app); BaseHTTPMiddleware already provides that,
    # so the redundant override is removed.

    async def dispatch(self, request, call_next):
        response = await call_next(request)

        # Prevent clickjacking
        response.headers["X-Frame-Options"] = "DENY"

        # Prevent MIME type sniffing
        response.headers["X-Content-Type-Options"] = "nosniff"

        # Enable XSS filtering in browser (legacy but good for depth)
        response.headers["X-XSS-Protection"] = "1; mode=block"

        # Strict Transport Security (HSTS): enforce HTTPS for 1 year,
        # including subdomains. Only effective if served over HTTPS.
        # (No `preload` directive is sent - add it deliberately if the
        # domain should be submitted to browser preload lists.)
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        # Content Security Policy: primarily 'self'; 'unsafe-inline' is
        # kept for Swagger UI compatibility.
        response.headers["Content-Security-Policy"] = "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline';"

        # Referrer Policy
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        return response
|
backend/app/core/ws_security.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WebSocket Security Utilities
|
| 3 |
+
Authentication and validation for WebSocket connections.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
from typing import Optional, Set
|
| 9 |
+
from urllib.parse import urlparse
|
| 10 |
+
|
| 11 |
+
from fastapi import WebSocket, WebSocketException, status
|
| 12 |
+
from jose import jwt, JWTError
|
| 13 |
+
|
| 14 |
+
from .config import get_settings
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
settings = get_settings()
|
| 18 |
+
|
| 19 |
+
# Allowed origins for WebSocket connections
|
| 20 |
+
_allowed_ws_origins: Optional[Set[str]] = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_allowed_origins() -> Set[str]:
    """Return (and lazily build) the set of permitted WebSocket origins."""
    global _allowed_ws_origins
    if _allowed_ws_origins is None:
        raw = os.getenv(
            "CORS_ORIGINS",
            "http://localhost:8501,http://localhost:3000,http://localhost:8000"
        )
        origins = {item.strip().rstrip('/') for item in raw.split(",")}

        # Running on HuggingFace Spaces (SPACE_ID is set there): also
        # trust the HF domains.
        space_id = os.getenv("SPACE_ID")
        if space_id:
            origins.add("https://huggingface.co")
            origins.add(f"https://{space_id}.hf.space")

        logger.info(f"WebSocket allowed origins: {origins}")
        _allowed_ws_origins = origins

    return _allowed_ws_origins
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def validate_ws_origin(websocket: WebSocket) -> bool:
    """Check the Origin header of *websocket* against the allow-list.

    Returns True when the origin is acceptable, False otherwise.
    """
    origin = websocket.headers.get("origin")

    # Absent Origin header: same-origin requests and non-browser clients
    # (curl, native apps) omit it. Permitted here for dev convenience;
    # a stricter production policy might reject these.
    if not origin:
        logger.warning("WebSocket connection without Origin header")
        return True

    origin = origin.rstrip('/')

    if origin in get_allowed_origins():
        return True

    # Any HuggingFace Spaces subdomain is also accepted (*.hf.space).
    if urlparse(origin).netloc.endswith('.hf.space'):
        return True

    logger.warning(f"WebSocket rejected: origin '{origin}' not in allowed list")
    return False
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
async def authenticate_websocket(
    websocket: WebSocket,
    token: Optional[str] = None
) -> Optional[int]:
    """Resolve the user id for a WebSocket connection via its JWT.

    The token may arrive either as an explicit argument or as the
    ``token`` query parameter (ws://host/ws/endpoint?token=xxx).

    Returns the user id on success, otherwise None.
    """
    token = token or websocket.query_params.get("token")
    if not token:
        return None

    try:
        claims = jwt.decode(
            token,
            settings.secret_key,
            algorithms=[settings.algorithm]
        )
        sub = claims.get("sub")
        if sub:
            return int(sub)
    except JWTError as e:
        logger.warning(f"WebSocket auth failed: {e}")
    except Exception as e:
        logger.error(f"WebSocket auth error: {e}")

    return None
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
async def require_ws_auth(websocket: WebSocket) -> int:
    """Authenticate a WebSocket or close it with policy-violation code 1008.

    Usage:
        @router.websocket("/secure/{client_id}")
        async def secure_ws(websocket: WebSocket, client_id: str):
            user_id = await require_ws_auth(websocket)
            await websocket.accept()
            # ... handle connection
    """
    # Origin check runs before any token work.
    if not validate_ws_origin(websocket):
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    user_id = await authenticate_websocket(websocket)
    if user_id:
        return user_id

    await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
    raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class WebSocketRateLimiter:
    """Per-connection throttle for WebSocket traffic.

    Tracks how many messages each client sent within the current
    one-second window and enforces a maximum message size.
    """

    def __init__(self, max_messages_per_second: int = 10, max_message_size: int = 1024 * 1024):
        self.max_rate = max_messages_per_second
        self.max_size = max_message_size
        self._counts: dict = {}  # client_id -> (count, last_reset_time)

    def check_rate(self, client_id: str) -> bool:
        """Return True if *client_id* may send another message right now."""
        import time
        now = time.time()

        entry = self._counts.get(client_id)
        # First message, or the 1-second window has elapsed: start fresh.
        if entry is None or now - entry[1] >= 1.0:
            self._counts[client_id] = (1, now)
            return True

        count, window_start = entry
        if count >= self.max_rate:
            logger.warning(f"WebSocket rate limit exceeded for {client_id}")
            return False

        self._counts[client_id] = (count + 1, window_start)
        return True

    def check_size(self, data: bytes) -> bool:
        """Return True if *data* does not exceed the configured size cap."""
        if len(data) <= self.max_size:
            return True
        logger.warning(f"WebSocket message too large: {len(data)} > {self.max_size}")
        return False

    def cleanup(self, client_id: str):
        """Forget all tracking state for a disconnected client."""
        self._counts.pop(client_id, None)


# Global rate limiter instance
ws_rate_limiter = WebSocketRateLimiter(max_messages_per_second=20, max_message_size=5 * 1024 * 1024)
|
backend/app/main.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge - FastAPI Main Application
|
| 3 |
+
Production-grade Speech-to-Text & Text-to-Speech API
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
# WARN: PyTorch 2.6+ security workaround for Pyannote
|
| 8 |
+
# Must be before any other torch imports
|
| 9 |
+
import os
|
| 10 |
+
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
|
| 11 |
+
import torch.serialization
|
| 12 |
+
try:
|
| 13 |
+
torch.serialization.add_safe_globals([dict])
|
| 14 |
+
except:
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
from contextlib import asynccontextmanager
|
| 18 |
+
from fastapi import FastAPI, Request
|
| 19 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
+
from fastapi.responses import JSONResponse
|
| 21 |
+
from fastapi.openapi.utils import get_openapi
|
| 22 |
+
|
| 23 |
+
from prometheus_fastapi_instrumentator import Instrumentator
|
| 24 |
+
from .core.config import get_settings
|
| 25 |
+
from .api.routes import (
|
| 26 |
+
stt_router,
|
| 27 |
+
tts_router,
|
| 28 |
+
health_router,
|
| 29 |
+
transcripts_router,
|
| 30 |
+
ws_router,
|
| 31 |
+
translation_router,
|
| 32 |
+
batch_router,
|
| 33 |
+
analysis_router,
|
| 34 |
+
audio_router,
|
| 35 |
+
cloning_router,
|
| 36 |
+
sign_router,
|
| 37 |
+
auth_router,
|
| 38 |
+
s2s_router,
|
| 39 |
+
sign_bridge # Import the module
|
| 40 |
+
)
|
| 41 |
+
from .models import Base, engine
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Configure logging
|
| 46 |
+
logging.basicConfig(
|
| 47 |
+
level=logging.INFO,
|
| 48 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 49 |
+
)
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
settings = get_settings()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan handler
    Runs on startup and shutdown
    """
    # ---- Startup ----
    logger.info(f"Starting {settings.app_name} v{settings.app_version}")

    # Create database tables
    logger.info("Creating database tables...")
    Base.metadata.create_all(bind=engine)

    # Pre-warm Whisper models for faster first request
    logger.info("Pre-warming AI models...")
    try:
        from .services.whisper_stt_service import get_whisper_model

        # English Distil model first (most common), then the multilingual one.
        get_whisper_model("distil-small.en")
        logger.info("✅ Distil-Whisper model loaded")
        get_whisper_model("small")
        logger.info("✅ Whisper-small model loaded")
    except Exception as e:
        logger.warning(f"Model pre-warming failed: {e}")

    # Pre-cache TTS voice list
    try:
        from .services.tts_service import get_tts_service

        await get_tts_service().get_voices()
        logger.info("✅ TTS voice list cached")
    except Exception as e:
        logger.warning(f"Voice list caching failed: {e}")

    logger.info("🚀 Startup complete - All models warmed up!")

    yield

    # ---- Shutdown ----
    logger.info("Shutting down...")
    # TODO: Close database connections
    # TODO: Close Redis connections
    logger.info("Shutdown complete")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Create FastAPI application
app = FastAPI(
    title=settings.app_name,
    description="""
## VoiceForge API

Production-grade Speech-to-Text and Text-to-Speech API.

### Features

- 🎤 **Speech-to-Text**: Transcribe audio files with word-level timestamps
- 🔊 **Text-to-Speech**: Synthesize speech with 300+ neural voices
- 🌍 **Multi-language**: Support for 10+ languages
- 🧠 **AI Analysis**: Sentiment, keywords, and summarization
- 🌐 **Translation**: Translate text/audio between 20+ languages
- ⚡ **Free & Fast**: Local Whisper + Edge TTS - no API costs
""",
    version=settings.app_version,
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)


from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from .core.limiter import limiter
from .core.security_headers import SecurityHeadersMiddleware
from .core.request_size_middleware import RequestSizeLimitMiddleware

# Middleware registration order matters: size cap first, then rate
# limiting, security headers, and finally CORS.

# Request body size limit (must be first to reject large requests early)
app.add_middleware(RequestSizeLimitMiddleware)

# Rate limiting (default: 60 requests/min per IP)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)

# Security headers (before CORS so headers appear even on errors/CORS blocks)
app.add_middleware(SecurityHeadersMiddleware)

# CORS middleware - Use CORS_ORIGINS env var (comma-separated) or defaults
_cors_origins = os.getenv(
    "CORS_ORIGINS",
    "http://localhost:8501,http://localhost:3000,http://localhost:8000"
).split(",")
# Add HuggingFace origins if deploying there (HF Spaces sets SPACE_ID)
if os.getenv("SPACE_ID"):
    _cors_origins.extend(["https://huggingface.co", f"https://{os.getenv('SPACE_ID')}.hf.space"])

app.add_middleware(
    CORSMiddleware,
    allow_origins=[o.strip() for o in _cors_origins],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Prometheus metrics (exposed at /metrics)
Instrumentator().instrument(app).expose(app)


# Register routers: health is unprefixed; everything else lives under /api/v1.
app.include_router(health_router)
for _router in (
    auth_router,
    stt_router,
    tts_router,
    transcripts_router,
    ws_router,
    translation_router,
    batch_router,
    analysis_router,
    audio_router,
    cloning_router,
    sign_router,
    s2s_router,
    sign_bridge.router,
):
    app.include_router(_router, prefix="/api/v1")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# Exception handlers
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Global exception handler for unhandled errors"""
    logger.exception(f"Unhandled error: {exc}")
    # Only leak the exception text when running in debug mode.
    detail = str(exc) if settings.debug else None
    return JSONResponse(
        status_code=500,
        content={
            "error": "internal_server_error",
            "message": "An unexpected error occurred",
            "detail": detail,
        },
    )


@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
    """Handler for validation errors"""
    return JSONResponse(
        status_code=400,
        content={
            "error": "validation_error",
            "message": str(exc),
        },
    )


# Root endpoint
@app.get("/", tags=["Root"])
async def root():
    """API root - returns basic info"""
    return {
        "name": settings.app_name,
        "version": settings.app_version,
        "status": "running",
        "docs": "/docs",
        "health": "/health",
    }
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# Custom OpenAPI schema
def custom_openapi():
    """Generate custom OpenAPI schema with enhanced documentation"""
    # Serve the cached schema after the first build.
    if app.openapi_schema:
        return app.openapi_schema

    schema = get_openapi(
        title=settings.app_name,
        version=settings.app_version,
        description=app.description,
        routes=app.routes,
    )

    # Add custom logo
    schema["info"]["x-logo"] = {
        "url": "https://example.com/logo.png"
    }

    # Add tags with descriptions
    schema["tags"] = [
        {
            "name": "Health",
            "description": "Health check endpoints for monitoring",
        },
        {
            "name": "Speech-to-Text",
            "description": "Convert audio to text with timestamps and speaker detection",
        },
        {
            "name": "Text-to-Speech",
            "description": "Convert text to natural-sounding speech",
        },
    ]

    app.openapi_schema = schema
    return app.openapi_schema


app.openapi = custom_openapi


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app.main:app",
        host=settings.api_host,
        port=settings.api_port,
        reload=settings.debug,
    )
|
backend/app/models/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge Database Models Package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .base import Base, engine, SessionLocal, get_db
|
| 6 |
+
from .audio_file import AudioFile
|
| 7 |
+
from .transcript import Transcript
|
| 8 |
+
from .auth import User, ApiKey
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"Base",
|
| 12 |
+
"engine",
|
| 13 |
+
"SessionLocal",
|
| 14 |
+
"get_db",
|
| 15 |
+
"AudioFile",
|
| 16 |
+
"Transcript",
|
| 17 |
+
"User",
|
| 18 |
+
"ApiKey",
|
| 19 |
+
]
|
backend/app/models/audio_file.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio File Model
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, Enum
|
| 7 |
+
from sqlalchemy.orm import relationship
|
| 8 |
+
import enum
|
| 9 |
+
|
| 10 |
+
from .base import Base
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AudioFileStatus(str, enum.Enum):
    """Audio file processing status.

    Stored as its string value in AudioFile.status (a String column).
    """
    UPLOADED = "uploaded"      # received, not yet processed (column default)
    PROCESSING = "processing"  # processing currently in progress
    DONE = "done"              # processing finished successfully
    FAILED = "failed"          # processing failed; see AudioFile.error_message
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class AudioFile(Base):
|
| 22 |
+
"""Audio file database model"""
|
| 23 |
+
|
| 24 |
+
__tablename__ = "audio_files"
|
| 25 |
+
|
| 26 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 27 |
+
# user_id removed
|
| 28 |
+
storage_path = Column(String(500), nullable=False)
|
| 29 |
+
original_filename = Column(String(255), nullable=True)
|
| 30 |
+
duration = Column(Float, nullable=True) # Duration in seconds
|
| 31 |
+
format = Column(String(20), nullable=True) # wav, mp3, etc.
|
| 32 |
+
sample_rate = Column(Integer, nullable=True)
|
| 33 |
+
channels = Column(Integer, nullable=True)
|
| 34 |
+
file_size = Column(Integer, nullable=True) # Size in bytes
|
| 35 |
+
language = Column(String(10), nullable=True) # User-specified language
|
| 36 |
+
detected_language = Column(String(10), nullable=True) # Auto-detected language
|
| 37 |
+
status = Column(String(20), default=AudioFileStatus.UPLOADED.value, index=True)
|
| 38 |
+
error_message = Column(String(500), nullable=True)
|
| 39 |
+
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
| 40 |
+
processed_at = Column(DateTime, nullable=True)
|
| 41 |
+
|
| 42 |
+
# Relationships
|
| 43 |
+
# user relationship removed
|
| 44 |
+
transcripts = relationship("Transcript", back_populates="audio_file")
|
| 45 |
+
|
| 46 |
+
def __repr__(self):
|
| 47 |
+
return f"<AudioFile(id={self.id}, filename={self.original_filename}, status={self.status})>"
|
backend/app/models/auth.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
User and API Key Models
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, DateTime
|
| 6 |
+
from sqlalchemy.orm import relationship
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from .base import Base
|
| 9 |
+
from ..core.security_encryption import EncryptedString
|
| 10 |
+
|
| 11 |
+
class User(Base):
|
| 12 |
+
__tablename__ = "users"
|
| 13 |
+
|
| 14 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 15 |
+
email = Column(String, unique=True, index=True, nullable=False) # Not encrypted (needed for lookup)
|
| 16 |
+
hashed_password = Column(String, nullable=False)
|
| 17 |
+
full_name = Column(EncryptedString(255), nullable=True) # ENCRYPTED
|
| 18 |
+
is_active = Column(Boolean, default=True)
|
| 19 |
+
is_superuser = Column(Boolean, default=False)
|
| 20 |
+
|
| 21 |
+
# Relationships
|
| 22 |
+
api_keys = relationship("ApiKey", back_populates="user", cascade="all, delete-orphan")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ApiKey(Base):
|
| 26 |
+
__tablename__ = "api_keys"
|
| 27 |
+
|
| 28 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 29 |
+
key = Column(String, unique=True, index=True, nullable=False)
|
| 30 |
+
name = Column(String, nullable=True) # e.g. "Production App"
|
| 31 |
+
is_active = Column(Boolean, default=True)
|
| 32 |
+
created_at = Column(DateTime, default=datetime.utcnow)
|
| 33 |
+
last_used_at = Column(DateTime, nullable=True)
|
| 34 |
+
|
| 35 |
+
user_id = Column(Integer, ForeignKey("users.id"))
|
| 36 |
+
user = relationship("User", back_populates="api_keys")
|
backend/app/models/base.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
SQLAlchemy Base and Database Session
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from sqlalchemy import create_engine
|
| 6 |
+
from sqlalchemy.ext.declarative import declarative_base
|
| 7 |
+
from sqlalchemy.orm import sessionmaker
|
| 8 |
+
|
| 9 |
+
from ..core.config import get_settings
|
| 10 |
+
|
| 11 |
+
settings = get_settings()
|
| 12 |
+
|
| 13 |
+
# Create SQLAlchemy engine
|
| 14 |
+
# Create SQLAlchemy engine
|
| 15 |
+
if "sqlite" in settings.database_url:
|
| 16 |
+
engine = create_engine(
|
| 17 |
+
settings.database_url,
|
| 18 |
+
connect_args={"check_same_thread": False},
|
| 19 |
+
)
|
| 20 |
+
else:
|
| 21 |
+
engine = create_engine(
|
| 22 |
+
settings.database_url,
|
| 23 |
+
pool_pre_ping=True,
|
| 24 |
+
pool_size=10,
|
| 25 |
+
max_overflow=20,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Create session factory
|
| 29 |
+
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
| 30 |
+
|
| 31 |
+
# Create declarative base
|
| 32 |
+
Base = declarative_base()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def get_db():
|
| 36 |
+
"""
|
| 37 |
+
Database session dependency for FastAPI
|
| 38 |
+
Yields a database session and ensures cleanup
|
| 39 |
+
"""
|
| 40 |
+
db = SessionLocal()
|
| 41 |
+
try:
|
| 42 |
+
yield db
|
| 43 |
+
finally:
|
| 44 |
+
db.close()
|
backend/app/models/sign_lstm.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
import torch.nn as nn

class SignLSTM(nn.Module):
    """
    LSTM model for sign-language recognition.

    Architecture:
    - Input: sequence of MediaPipe landmarks (21 points * 3 coords = 63 features)
    - Hidden: stacked LSTM layers capturing temporal dynamics
    - Output: fully connected layer -> class logits (ASL alphabet or vocabulary)

    Why LSTM?
    It captures the 'motion' and 'context' of signs, not just static hand
    shapes, enabling dynamic gesture recognition.
    """
    def __init__(self, input_size=63, hidden_size=128, num_layers=2, num_classes=26):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True expects input shape: (batch, seq_len, features).
        # NOTE: nn.LSTM applies inter-layer dropout only when num_layers > 1.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=0.2
        )

        # Classification head over the final hidden state.
        self.fc = nn.Linear(hidden_size, num_classes)

        # Inference-only activation turning logits into probabilities.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        Return class logits.

        Args:
            x: tensor of shape (batch_size, sequence_length, input_size)
        Returns:
            tensor of shape (batch_size, num_classes)
        """
        # Hidden/cell states default to zeros when not supplied.
        # out shape: (batch_size, seq_len, hidden_size)
        out, _ = self.lstm(x)

        # Decode from the hidden state of the last time step only.
        out = self.fc(out[:, -1, :])

        return out

    def predict(self, x):
        """
        Helper for inference on a single batch.

        FIX: the module is switched to eval() for the duration of the call so
        that the LSTM's dropout is disabled — torch.no_grad() alone does NOT
        disable dropout, which previously made predictions stochastic.  The
        prior training mode is restored afterwards.

        Returns:
            (predicted_class: int, confidence: float) — uses .item(), so this
            expects a batch of size 1.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                logits = self.forward(x)
                probabilities = self.softmax(logits)
                confidence, predicted_class = torch.max(probabilities, 1)
                return predicted_class.item(), confidence.item()
        finally:
            self.train(was_training)
|
backend/app/models/transcript.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transcript Model
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON, Float
|
| 7 |
+
from sqlalchemy.orm import relationship
|
| 8 |
+
|
| 9 |
+
from .base import Base
|
| 10 |
+
from ..core.security_encryption import EncryptedString
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Transcript(Base):
|
| 14 |
+
"""Transcript database model"""
|
| 15 |
+
|
| 16 |
+
__tablename__ = "transcripts"
|
| 17 |
+
|
| 18 |
+
id = Column(Integer, primary_key=True, index=True)
|
| 19 |
+
audio_file_id = Column(Integer, ForeignKey("audio_files.id"), nullable=True, index=True)
|
| 20 |
+
audio_file_id = Column(Integer, ForeignKey("audio_files.id"), nullable=True, index=True)
|
| 21 |
+
# user_id removed (Auth disabled for portfolio)
|
| 22 |
+
|
| 23 |
+
# Transcript content - ENCRYPTED
|
| 24 |
+
raw_text = Column(EncryptedString(10000), nullable=True) # Original transcription
|
| 25 |
+
processed_text = Column(EncryptedString(10000), nullable=True) # After NLP processing
|
| 26 |
+
|
| 27 |
+
# Segments with timestamps and speaker info (JSON array)
|
| 28 |
+
# Format: [{"start": 0.0, "end": 1.5, "text": "Hello", "speaker": "SPEAKER_1", "confidence": 0.95}]
|
| 29 |
+
segments = Column(JSON, nullable=True)
|
| 30 |
+
|
| 31 |
+
# Word-level timestamps (JSON array)
|
| 32 |
+
# Format: [{"word": "hello", "start": 0.0, "end": 0.5, "confidence": 0.98}]
|
| 33 |
+
words = Column(JSON, nullable=True)
|
| 34 |
+
|
| 35 |
+
# Language info
|
| 36 |
+
language = Column(String(10), nullable=True) # Transcription language
|
| 37 |
+
translation_language = Column(String(10), nullable=True) # If translated
|
| 38 |
+
translated_text = Column(Text, nullable=True)
|
| 39 |
+
|
| 40 |
+
# NLP Analysis (Phase 2)
|
| 41 |
+
sentiment = Column(JSON, nullable=True) # {"overall": "positive", "score": 0.8, "segments": [...]}
|
| 42 |
+
topics = Column(JSON, nullable=True) # ["technology", "business"]
|
| 43 |
+
keywords = Column(JSON, nullable=True) # [{"word": "AI", "score": 0.9}]
|
| 44 |
+
action_items = Column(JSON, nullable=True) # [{"text": "Email John", "assignee": "Speaker 1"}]
|
| 45 |
+
attendees = Column(JSON, nullable=True) # ["Speaker 1", "Speaker 2"]
|
| 46 |
+
summary = Column(EncryptedString(5000), nullable=True) # ENCRYPTED
|
| 47 |
+
|
| 48 |
+
# Metadata
|
| 49 |
+
confidence = Column(Float, nullable=True) # Overall confidence score
|
| 50 |
+
duration = Column(Float, nullable=True) # Audio duration in seconds
|
| 51 |
+
word_count = Column(Integer, nullable=True)
|
| 52 |
+
|
| 53 |
+
# Timestamps
|
| 54 |
+
created_at = Column(DateTime, default=datetime.utcnow, index=True)
|
| 55 |
+
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
| 56 |
+
|
| 57 |
+
# Relationships
|
| 58 |
+
audio_file = relationship("AudioFile", back_populates="transcripts")
|
| 59 |
+
audio_file = relationship("AudioFile", back_populates="transcripts")
|
| 60 |
+
# user relationship removed
|
| 61 |
+
|
| 62 |
+
def __repr__(self):
|
| 63 |
+
preview = self.raw_text[:50] + "..." if self.raw_text and len(self.raw_text) > 50 else self.raw_text
|
| 64 |
+
return f"<Transcript(id={self.id}, preview='{preview}')>"
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Import Float for confidence field
|
backend/app/schemas/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge Schemas Package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .stt import (
|
| 6 |
+
TranscriptionRequest,
|
| 7 |
+
TranscriptionResponse,
|
| 8 |
+
TranscriptionSegment,
|
| 9 |
+
TranscriptionWord,
|
| 10 |
+
LanguageInfo,
|
| 11 |
+
)
|
| 12 |
+
from .tts import (
|
| 13 |
+
SynthesisRequest,
|
| 14 |
+
SynthesisResponse,
|
| 15 |
+
VoiceInfo,
|
| 16 |
+
VoiceListResponse,
|
| 17 |
+
)
|
| 18 |
+
from .transcript import (
|
| 19 |
+
TranscriptCreate,
|
| 20 |
+
TranscriptUpdate,
|
| 21 |
+
TranscriptResponse,
|
| 22 |
+
TranscriptListResponse,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
__all__ = [
|
| 26 |
+
"TranscriptionRequest",
|
| 27 |
+
"TranscriptionResponse",
|
| 28 |
+
"TranscriptionSegment",
|
| 29 |
+
"TranscriptionWord",
|
| 30 |
+
"LanguageInfo",
|
| 31 |
+
"SynthesisRequest",
|
| 32 |
+
"SynthesisResponse",
|
| 33 |
+
"VoiceInfo",
|
| 34 |
+
"VoiceListResponse",
|
| 35 |
+
"TranscriptCreate",
|
| 36 |
+
"TranscriptUpdate",
|
| 37 |
+
"TranscriptResponse",
|
| 38 |
+
"TranscriptListResponse",
|
| 39 |
+
]
|
backend/app/schemas/stt.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Speech-to-Text Schemas
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import List, Optional, Dict, Any
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TranscriptionWord(BaseModel):
|
| 11 |
+
"""Individual word with timing information"""
|
| 12 |
+
word: str
|
| 13 |
+
start_time: float = Field(..., description="Start time in seconds")
|
| 14 |
+
end_time: float = Field(..., description="End time in seconds")
|
| 15 |
+
confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TranscriptionSegment(BaseModel):
|
| 19 |
+
"""Transcript segment with speaker and timing"""
|
| 20 |
+
text: str
|
| 21 |
+
start_time: float = Field(..., description="Start time in seconds")
|
| 22 |
+
end_time: float = Field(..., description="End time in seconds")
|
| 23 |
+
speaker: Optional[str] = Field(None, description="Speaker label (e.g., SPEAKER_1)")
|
| 24 |
+
confidence: float = Field(..., ge=0.0, le=1.0)
|
| 25 |
+
words: Optional[List[TranscriptionWord]] = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TranscriptionRequest(BaseModel):
|
| 29 |
+
"""Request parameters for transcription"""
|
| 30 |
+
language: str = Field(default="en-US", description="Language code (e.g., en-US)")
|
| 31 |
+
enable_automatic_punctuation: bool = True
|
| 32 |
+
enable_word_time_offsets: bool = True
|
| 33 |
+
enable_speaker_diarization: bool = False
|
| 34 |
+
diarization_speaker_count: Optional[int] = Field(None, ge=2, le=10)
|
| 35 |
+
model: str = Field(default="default", description="STT model to use")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class TranscriptionResponse(BaseModel):
|
| 39 |
+
"""Response from transcription"""
|
| 40 |
+
id: Optional[int] = None
|
| 41 |
+
audio_file_id: Optional[int] = None
|
| 42 |
+
text: str = Field(..., description="Full transcription text")
|
| 43 |
+
segments: List[TranscriptionSegment] = Field(default_factory=list)
|
| 44 |
+
words: Optional[List[TranscriptionWord]] = None
|
| 45 |
+
language: str
|
| 46 |
+
detected_language: Optional[str] = None
|
| 47 |
+
confidence: float = Field(..., ge=0.0, le=1.0)
|
| 48 |
+
duration: float = Field(..., description="Audio duration in seconds")
|
| 49 |
+
word_count: int
|
| 50 |
+
processing_time: float = Field(..., description="Processing time in seconds")
|
| 51 |
+
|
| 52 |
+
model_config = {
|
| 53 |
+
"from_attributes": True
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class StreamingTranscriptionResponse(BaseModel):
|
| 58 |
+
"""Response for streaming transcription updates"""
|
| 59 |
+
is_final: bool = False
|
| 60 |
+
text: str
|
| 61 |
+
confidence: float = Field(default=0.0, ge=0.0, le=1.0)
|
| 62 |
+
stability: float = Field(default=0.0, ge=0.0, le=1.0)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class LanguageInfo(BaseModel):
|
| 66 |
+
"""Language information for UI display"""
|
| 67 |
+
code: str = Field(..., description="Language code (e.g., en-US)")
|
| 68 |
+
name: str = Field(..., description="Display name (e.g., English (US))")
|
| 69 |
+
native_name: str = Field(..., description="Native name (e.g., English)")
|
| 70 |
+
flag: str = Field(..., description="Flag emoji")
|
| 71 |
+
stt_supported: bool = True
|
| 72 |
+
tts_supported: bool = True
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class LanguageListResponse(BaseModel):
|
| 76 |
+
"""Response with list of supported languages"""
|
| 77 |
+
languages: List[LanguageInfo]
|
| 78 |
+
total: int
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TaskStatusResponse(BaseModel):
|
| 83 |
+
"""Status of an async transcription task"""
|
| 84 |
+
task_id: str
|
| 85 |
+
status: str = Field(..., description="pending, processing, completed, failed")
|
| 86 |
+
progress: float = Field(default=0.0, ge=0.0, le=100.0, description="Progress percentage")
|
| 87 |
+
result: Optional[TranscriptionResponse] = None
|
| 88 |
+
error: Optional[str] = None
|
| 89 |
+
created_at: datetime
|
| 90 |
+
updated_at: datetime
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class AsyncTranscriptionResponse(BaseModel):
|
| 94 |
+
"""Response for async transcription submission"""
|
| 95 |
+
task_id: str
|
| 96 |
+
audio_file_id: int
|
| 97 |
+
status: str = "queued"
|
| 98 |
+
message: str = "File uploaded and queued for processing"
|
backend/app/schemas/transcript.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transcript Schemas
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from typing import List, Optional, Dict, Any
|
| 7 |
+
from pydantic import BaseModel, Field
|
| 8 |
+
|
| 9 |
+
from .stt import TranscriptionSegment, TranscriptionWord
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TranscriptCreate(BaseModel):
|
| 13 |
+
"""Schema for creating a transcript"""
|
| 14 |
+
raw_text: str
|
| 15 |
+
processed_text: Optional[str] = None
|
| 16 |
+
segments: Optional[List[Dict[str, Any]]] = None
|
| 17 |
+
words: Optional[List[Dict[str, Any]]] = None
|
| 18 |
+
language: str = "en-US"
|
| 19 |
+
confidence: Optional[float] = None
|
| 20 |
+
duration: Optional[float] = None
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TranscriptUpdate(BaseModel):
|
| 24 |
+
"""Schema for updating a transcript"""
|
| 25 |
+
processed_text: Optional[str] = None
|
| 26 |
+
language: Optional[str] = None
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class TranscriptResponse(BaseModel):
|
| 30 |
+
"""Schema for transcript response"""
|
| 31 |
+
id: int
|
| 32 |
+
audio_file_id: Optional[int] = None
|
| 33 |
+
user_id: Optional[int] = None
|
| 34 |
+
raw_text: Optional[str] = None
|
| 35 |
+
processed_text: Optional[str] = None
|
| 36 |
+
segments: Optional[List[Dict[str, Any]]] = None
|
| 37 |
+
words: Optional[List[Dict[str, Any]]] = None
|
| 38 |
+
language: Optional[str] = None
|
| 39 |
+
translation_language: Optional[str] = None
|
| 40 |
+
translated_text: Optional[str] = None
|
| 41 |
+
sentiment: Optional[Dict[str, Any]] = None
|
| 42 |
+
topics: Optional[List[str]] = None
|
| 43 |
+
keywords: Optional[List[Dict[str, Any]]] = None
|
| 44 |
+
summary: Optional[str] = None
|
| 45 |
+
confidence: Optional[float] = None
|
| 46 |
+
duration: Optional[float] = None
|
| 47 |
+
word_count: Optional[int] = None
|
| 48 |
+
created_at: datetime
|
| 49 |
+
updated_at: Optional[datetime] = None
|
| 50 |
+
|
| 51 |
+
model_config = {
|
| 52 |
+
"from_attributes": True
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TranscriptListResponse(BaseModel):
|
| 57 |
+
"""Schema for paginated transcript list"""
|
| 58 |
+
transcripts: List[TranscriptResponse]
|
| 59 |
+
total: int
|
| 60 |
+
page: int
|
| 61 |
+
page_size: int
|
| 62 |
+
has_more: bool
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ExportRequest(BaseModel):
|
| 66 |
+
"""Schema for transcript export request"""
|
| 67 |
+
format: str = Field(..., pattern="^(txt|srt|vtt|pdf|json)$")
|
| 68 |
+
include_timestamps: bool = True
|
| 69 |
+
include_speakers: bool = True
|
backend/app/schemas/tts.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Text-to-Speech Schemas
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SynthesisRequest(BaseModel):
|
| 10 |
+
"""Request for text-to-speech synthesis"""
|
| 11 |
+
text: str = Field(..., min_length=1, max_length=5000, description="Text to synthesize")
|
| 12 |
+
language: str = Field(default="en-US", description="Language code")
|
| 13 |
+
voice: Optional[str] = Field(None, description="Voice name (e.g., en-US-Wavenet-D)")
|
| 14 |
+
|
| 15 |
+
# Audio configuration
|
| 16 |
+
audio_encoding: str = Field(default="MP3", description="Output format: MP3, LINEAR16, OGG_OPUS")
|
| 17 |
+
sample_rate: int = Field(default=24000, description="Sample rate in Hz")
|
| 18 |
+
|
| 19 |
+
# Voice tuning
|
| 20 |
+
speaking_rate: float = Field(default=1.0, ge=0.25, le=4.0, description="Speaking rate")
|
| 21 |
+
pitch: float = Field(default=0.0, ge=-20.0, le=20.0, description="Voice pitch in semitones")
|
| 22 |
+
volume_gain_db: float = Field(default=0.0, ge=-96.0, le=16.0, description="Volume gain in dB")
|
| 23 |
+
|
| 24 |
+
# SSML support
|
| 25 |
+
use_ssml: bool = Field(default=False, description="Treat text as SSML")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class SynthesisResponse(BaseModel):
|
| 29 |
+
"""Response from text-to-speech synthesis"""
|
| 30 |
+
audio_content: str = Field(..., description="Base64 encoded audio")
|
| 31 |
+
audio_size: int = Field(..., description="Audio size in bytes")
|
| 32 |
+
duration_estimate: float = Field(..., description="Estimated duration in seconds")
|
| 33 |
+
voice_used: str
|
| 34 |
+
language: str
|
| 35 |
+
encoding: str
|
| 36 |
+
sample_rate: int
|
| 37 |
+
processing_time: float = Field(..., description="Processing time in seconds")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class VoiceInfo(BaseModel):
|
| 41 |
+
"""Information about a TTS voice"""
|
| 42 |
+
name: str = Field(..., description="Voice name (e.g., en-US-Wavenet-D)")
|
| 43 |
+
language_code: str = Field(..., description="Language code")
|
| 44 |
+
language_name: str = Field(..., description="Language display name")
|
| 45 |
+
ssml_gender: str = Field(..., description="MALE, FEMALE, or NEUTRAL")
|
| 46 |
+
natural_sample_rate: int = Field(..., description="Native sample rate in Hz")
|
| 47 |
+
voice_type: str = Field(..., description="Standard, WaveNet, or Neural2")
|
| 48 |
+
|
| 49 |
+
# Display helpers
|
| 50 |
+
display_name: Optional[str] = None
|
| 51 |
+
flag: Optional[str] = None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class VoiceListResponse(BaseModel):
|
| 55 |
+
"""Response with list of available voices"""
|
| 56 |
+
voices: List[VoiceInfo]
|
| 57 |
+
total: int
|
| 58 |
+
language_filter: Optional[str] = None
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class VoicePreviewRequest(BaseModel):
|
| 62 |
+
"""Request for voice preview"""
|
| 63 |
+
voice: str = Field(..., description="Voice name to preview")
|
| 64 |
+
text: Optional[str] = Field(
|
| 65 |
+
default="Hello! This is a preview of my voice.",
|
| 66 |
+
max_length=200
|
| 67 |
+
)
|
backend/app/services/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
VoiceForge Services Package
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .stt_service import STTService
|
| 6 |
+
from .tts_service import TTSService
|
| 7 |
+
from .file_service import FileService
|
| 8 |
+
|
| 9 |
+
__all__ = [
|
| 10 |
+
"STTService",
|
| 11 |
+
"TTSService",
|
| 12 |
+
"FileService",
|
| 13 |
+
]
|
backend/app/services/audio_service.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Audio Editing Service
|
| 3 |
+
Handles audio manipulation: Trimming, Merging, and Conversion using Pydub/FFmpeg
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
from typing import List, Optional
|
| 9 |
+
from pydub import AudioSegment
|
| 10 |
+
import tempfile
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class AudioService:
    """
    Service for audio manipulation tasks (load, trim, merge, convert).
    Requires ffmpeg to be installed/available in path.
    """

    def __init__(self):
        pass

    def load_audio(self, file_path: str) -> AudioSegment:
        """Load an audio file into a Pydub AudioSegment.

        Raises:
            ValueError: if the file cannot be read/decoded.
        """
        try:
            return AudioSegment.from_file(file_path)
        except Exception as e:
            logger.error(f"Failed to load audio {file_path}: {e}")
            raise ValueError(f"Could not load audio file: {str(e)}")

    def trim_audio(self, input_path: str, start_ms: int, end_ms: int, output_path: Optional[str] = None) -> str:
        """
        Trim audio from start_ms to end_ms (milliseconds).

        end_ms past the end of the audio is clamped by the slice;
        start_ms at/past the end raises ValueError.
        """
        if start_ms < 0 or end_ms <= start_ms:
            raise ValueError("Invalid start/end timestamps")

        audio = self.load_audio(input_path)

        # Check duration
        if start_ms >= len(audio):
            raise ValueError("Start time exceeds audio duration")

        # Slice
        trimmed = audio[start_ms:end_ms]

        if not output_path:
            base, ext = os.path.splitext(input_path)
            output_path = f"{base}_trimmed{ext}"

        # FIX: fall back to mp3 when the output path has no extension
        # (matches merge_audio); previously an empty format string was
        # passed to pydub and the export failed.
        fmt = os.path.splitext(output_path)[1][1:] or "mp3"
        trimmed.export(output_path, format=fmt)
        logger.info(f"Trimmed audio saved to {output_path}")
        return output_path

    def merge_audio(self, file_paths: List[str], output_path: str, crossfade_ms: int = 0) -> str:
        """
        Merge multiple audio files into one, with an optional crossfade
        between consecutive segments.
        """
        if not file_paths:
            raise ValueError("No files to merge")

        combined = AudioSegment.empty()

        for path in file_paths:
            segment = self.load_audio(path)
            if crossfade_ms > 0 and len(combined) > 0:
                combined = combined.append(segment, crossfade=crossfade_ms)
            else:
                combined += segment

        # FIX: os.makedirs("") raises FileNotFoundError when output_path has
        # no directory component — only create the directory when one exists.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Export
        fmt = os.path.splitext(output_path)[1][1:] or "mp3"
        combined.export(output_path, format=fmt)
        logger.info(f"Merged {len(file_paths)} files to {output_path}")
        return output_path

    def convert_format(self, input_path: str, target_format: str) -> str:
        """
        Convert audio format (e.g. wav -> mp3).  Returns the new file path.
        """
        audio = self.load_audio(input_path)

        base = os.path.splitext(input_path)[0]
        output_path = f"{base}.{target_format}"

        audio.export(output_path, format=target_format)
        logger.info(f"Converted to {target_format}: {output_path}")
        return output_path
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Lazily-created module-level singleton instance.
_audio_service = None


def get_audio_service() -> AudioService:
    """Return the shared AudioService, creating it on first use."""
    global _audio_service
    if _audio_service is None:
        _audio_service = AudioService()
    return _audio_service
|
backend/app/services/batch_service.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Batch Processing Service
|
| 3 |
+
Handles multi-file transcription with job tracking and parallel processing
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
import tempfile
|
| 10 |
+
import uuid
|
| 11 |
+
import zipfile
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List, Optional, Any
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from enum import Enum
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class JobStatus(str, Enum):
    """Lifecycle states for an entire batch job (str-valued for JSON)."""
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


class FileStatus(str, Enum):
    """Lifecycle states for a single file within a batch."""
    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


@dataclass
class FileResult:
    """Per-file outcome inside a batch job."""
    filename: str
    status: FileStatus = FileStatus.QUEUED
    progress: float = 0.0                    # 0..100 for this file
    transcript: Optional[str] = None         # full transcript text
    language: Optional[str] = None           # detected/forced language code
    duration: Optional[float] = None         # audio duration in seconds
    word_count: Optional[int] = None
    processing_time: Optional[float] = None  # wall-clock seconds
    error: Optional[str] = None              # failure reason, if any
    output_path: Optional[str] = None        # rendered txt/srt file on disk


@dataclass
class BatchJob:
    """A multi-file transcription job with aggregate progress tracking."""
    job_id: str
    status: JobStatus = JobStatus.PENDING
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    files: Dict[str, FileResult] = field(default_factory=dict)
    total_files: int = 0
    completed_files: int = 0
    failed_files: int = 0
    options: Dict[str, Any] = field(default_factory=dict)
    output_zip_path: Optional[str] = None

    @property
    def progress(self) -> float:
        """Percentage of files in a terminal state (completed or failed)."""
        if not self.total_files:
            return 0.0
        finished = self.completed_files + self.failed_files
        return finished / self.total_files * 100

    @staticmethod
    def _clip_transcript(text: Optional[str]) -> Optional[str]:
        """Shorten long transcripts for API previews (first 500 chars + ellipsis)."""
        if text and len(text) > 500:
            return text[:500] + "..."
        return text

    def _file_summary(self, f: FileResult) -> Dict[str, Any]:
        """Serialize one FileResult for the API payload."""
        return {
            "filename": f.filename,
            "status": f.status.value,
            "progress": f.progress,
            "transcript": self._clip_transcript(f.transcript),
            "language": f.language,
            "duration": f.duration,
            "word_count": f.word_count,
            "processing_time": f.processing_time,
            "error": f.error,
        }

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API response."""
        def iso(ts: Optional[datetime]) -> Optional[str]:
            return ts.isoformat() if ts else None

        return {
            "job_id": self.job_id,
            "status": self.status.value,
            "progress": round(self.progress, 1),
            "created_at": self.created_at.isoformat(),
            "started_at": iso(self.started_at),
            "completed_at": iso(self.completed_at),
            "total_files": self.total_files,
            "completed_files": self.completed_files,
            "failed_files": self.failed_files,
            "files": {name: self._file_summary(f) for name, f in self.files.items()},
            "options": self.options,
            "has_zip": self.output_zip_path is not None,
        }
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# In-memory job store (use Redis in production — this store is per-process
# and lost on restart).
_batch_jobs: Dict[str, "BatchJob"] = {}


class BatchProcessingService:
    """
    Service for batch audio transcription.

    Creates jobs, dispatches each file to the Celery transcription worker,
    writes per-file txt/srt outputs, and bundles all results into a ZIP.
    """

    def __init__(self, output_dir: Optional[str] = None):
        """Initialize batch service.

        Args:
            output_dir: Directory for rendered transcripts and ZIPs
                        (defaults to the system temp dir).
        """
        self.output_dir = output_dir or tempfile.gettempdir()
        self._processing_lock = asyncio.Lock()

    def create_job(
        self,
        filenames: List[str],
        options: Optional[Dict[str, Any]] = None,
    ) -> "BatchJob":
        """
        Create a new batch job.

        Args:
            filenames: List of filenames to process
            options: Processing options (language, output_format, etc.)

        Returns:
            Created BatchJob
        """
        # A short random ID is sufficient for this in-memory, low-volume store.
        job_id = str(uuid.uuid4())[:8]

        job = BatchJob(
            job_id=job_id,
            files={name: FileResult(filename=name) for name in filenames},
            total_files=len(filenames),
            options=options or {},
        )

        _batch_jobs[job_id] = job
        logger.info("Created batch job %s with %d files", job_id, len(filenames))
        return job

    def get_job(self, job_id: str) -> Optional["BatchJob"]:
        """Get job by ID, or None if unknown."""
        return _batch_jobs.get(job_id)

    def list_jobs(self, limit: int = 20) -> List["BatchJob"]:
        """List the most recent jobs, newest first."""
        jobs = sorted(_batch_jobs.values(), key=lambda j: j.created_at, reverse=True)
        return jobs[:limit]

    async def process_job(
        self,
        job_id: str,
        file_paths: Dict[str, str],
    ) -> "BatchJob":
        """
        Process all files in a batch job.

        Each file is dispatched to the Celery worker and awaited with a
        blocking ``task.get()`` (NOTE: acceptable because this coroutine
        runs in a background thread; it would stall a shared event loop).

        Args:
            job_id: Job ID
            file_paths: Mapping of filename -> temp file path

        Returns:
            Completed BatchJob

        Raises:
            ValueError: If the job ID is unknown.
        """
        import time

        job = self.get_job(job_id)
        if not job:
            raise ValueError(f"Job not found: {job_id}")

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.now()

        # Processing options supplied at job creation.
        language = job.options.get("language")
        output_format = job.options.get("output_format", "txt")

        output_files: List[str] = []

        for filename, file_path in file_paths.items():
            file_result = job.files.get(filename)
            if not file_result:
                # Path supplied for a file that was never registered; skip it.
                continue

            file_result.status = FileStatus.PROCESSING
            file_result.progress = 0.0

            try:
                start_time = time.time()

                # Dispatch transcription to the Celery worker pool.
                from app.workers.tasks import transcribe_file_path

                task = transcribe_file_path.delay(
                    file_path=file_path,
                    language=language,
                    output_format=output_format,
                )
                task_result = task.get(timeout=600)  # 10 min timeout per file

                processing_time = time.time() - start_time

                # Record per-file outcome.
                file_result.transcript = task_result.get("text", "")
                file_result.language = task_result.get("language", "unknown")
                file_result.duration = task_result.get("duration")
                file_result.word_count = len(file_result.transcript.split())
                file_result.processing_time = round(processing_time, 2)
                file_result.status = FileStatus.COMPLETED
                file_result.progress = 100.0

                output_path = self._write_output(
                    job_id,
                    filename,
                    output_format,
                    file_result.transcript,
                    task_result.get("segments", []),
                )
                file_result.output_path = output_path
                output_files.append(output_path)

                job.completed_files += 1
                # Fix: log the actual filename (was a "(unknown)" placeholder).
                logger.info(
                    "[%s] Completed %s (%d/%d)",
                    job_id, filename, job.completed_files, job.total_files,
                )

            except Exception as e:
                file_result.status = FileStatus.FAILED
                file_result.error = str(e)
                file_result.progress = 0.0
                job.failed_files += 1
                logger.error("[%s] Failed %s: %s", job_id, filename, e)

            finally:
                # Best-effort cleanup of the uploaded temp file.
                try:
                    if os.path.exists(file_path):
                        os.unlink(file_path)
                except OSError:
                    pass

        # Bundle every rendered transcript into one downloadable ZIP.
        if output_files:
            zip_path = os.path.join(self.output_dir, f"{job_id}_results.zip")
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                for out_path in output_files:
                    zf.write(out_path, os.path.basename(out_path))

            job.output_zip_path = zip_path
            logger.info("[%s] Created ZIP: %s", job_id, zip_path)

        # Any failed file marks the whole job FAILED (partial results are kept).
        job.status = JobStatus.COMPLETED if job.failed_files == 0 else JobStatus.FAILED
        job.completed_at = datetime.now()

        return job

    def _write_output(
        self,
        job_id: str,
        filename: str,
        output_format: str,
        transcript: str,
        segments: List[Dict[str, Any]],
    ) -> str:
        """Render one transcript to <output_dir>/<job_id>/ and return its path."""
        output_filename = Path(filename).stem + f".{output_format}"
        output_path = os.path.join(self.output_dir, job_id, output_filename)
        os.makedirs(os.path.dirname(output_path), exist_ok=True)

        with open(output_path, "w", encoding="utf-8") as f:
            if output_format == "srt":
                # Write SRT format: index, time range, text, blank separator.
                for i, seg in enumerate(segments, 1):
                    start = self._format_srt_time(seg.get("start", 0))
                    end = self._format_srt_time(seg.get("end", 0))
                    text = seg.get("text", "").strip()
                    f.write(f"{i}\n{start} --> {end}\n{text}\n\n")
            else:
                f.write(transcript)

        return output_path

    def _format_srt_time(self, seconds: float) -> str:
        """Format seconds to SRT time format (HH:MM:SS,mmm)."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        millis = int((seconds % 1) * 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def cancel_job(self, job_id: str) -> bool:
        """Mark a pending/processing job cancelled. Returns True if it was."""
        job = self.get_job(job_id)
        if job and job.status in (JobStatus.PENDING, JobStatus.PROCESSING):
            job.status = JobStatus.CANCELLED
            return True
        return False

    def delete_job(self, job_id: str) -> bool:
        """Delete a job and its output files. Returns True if the job existed."""
        job = _batch_jobs.pop(job_id, None)
        if not job:
            return False

        # Best-effort cleanup; a stale file is not worth failing the request.
        if job.output_zip_path and os.path.exists(job.output_zip_path):
            try:
                os.unlink(job.output_zip_path)
            except OSError:
                pass

        job_dir = os.path.join(self.output_dir, job_id)
        if os.path.exists(job_dir):
            try:
                import shutil
                shutil.rmtree(job_dir)
            except OSError:
                pass

        return True

    def get_zip_path(self, job_id: str) -> Optional[str]:
        """Get path to job's output ZIP file, or None if absent on disk."""
        job = self.get_job(job_id)
        if job and job.output_zip_path and os.path.exists(job.output_zip_path):
            return job.output_zip_path
        return None


# Singleton instance
_batch_service: Optional[BatchProcessingService] = None


def get_batch_service() -> BatchProcessingService:
    """Get or create BatchProcessingService singleton."""
    global _batch_service
    if _batch_service is None:
        _batch_service = BatchProcessingService()
    return _batch_service
|
backend/app/services/cache_service.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import redis
|
| 2 |
+
import json
|
| 3 |
+
import hashlib
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Optional, Any
|
| 6 |
+
from functools import lru_cache
|
| 7 |
+
|
| 8 |
+
from ..core.config import get_settings
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class CacheService:
    """
    Byte-oriented cache: Redis primary with a DiskCache fallback.

    Values are opaque ``bytes``; callers handle serialization themselves.
    Every cache failure is logged and swallowed so caching can never break
    a request (a miss is always a safe outcome).
    """

    def __init__(self):
        settings = get_settings()
        self.default_ttl = 3600  # seconds (1 hour)
        self.redis = None
        self.disk_cache = None

        # Try Redis first
        try:
            self.redis = redis.from_url(settings.redis_url, decode_responses=False)
            self.redis.ping()  # fail fast if the server is unreachable
            logger.info("✅ Redis Cache connected")
        except Exception as e:
            logger.warning(f"⚠️ Redis unavailable, falling back to DiskCache: {e}")
            self.redis = None

        # Fallback to DiskCache (only needed when Redis is unavailable,
        # since get/set always prefer the Redis backend).
        if self.redis is None:
            try:
                import diskcache
                cache_dir = "./cache_data"
                self.disk_cache = diskcache.Cache(cache_dir)
                logger.info(f"💾 DiskCache initialized at {cache_dir}")
            except Exception as e:
                logger.error(f"❌ DiskCache init failed: {e}")

    def get(self, key: str) -> Optional[bytes]:
        """Get raw bytes from cache; None on miss, error, or no backend."""
        try:
            if self.redis:
                return self.redis.get(key)
            elif self.disk_cache:
                return self.disk_cache.get(key)
        except Exception as e:
            logger.error(f"Cache get failed: {e}")
        return None

    def set(self, key: str, value: bytes, ttl: Optional[int] = None):
        """Set raw bytes in cache (best-effort, never raises).

        Args:
            key: Cache key (use :meth:`generate_key` for stable keys).
            value: Raw payload bytes.
            ttl: Expiry in seconds; None (or 0, which is not a valid expiry)
                 falls back to ``default_ttl``.
        """
        try:
            ttl_val = ttl or self.default_ttl

            if self.redis:
                self.redis.setex(key, ttl_val, value)
            elif self.disk_cache:
                self.disk_cache.set(key, value, expire=ttl_val)
        except Exception as e:
            logger.error(f"Cache set failed: {e}")

    def generate_key(self, prefix: str, **kwargs) -> str:
        """Generate a stable cache key from arguments.

        Keyword order does not matter: values are stringified and the JSON
        dump is key-sorted before hashing. MD5 is used purely for key
        derivation here, not for security.
        """
        # Convert all values to string for stability across types.
        safe_kwargs = {k: str(v) for k, v in kwargs.items()}
        key_str = json.dumps(safe_kwargs, sort_keys=True)
        hash_str = hashlib.md5(key_str.encode()).hexdigest()
        return f"{prefix}:{hash_str}"


@lru_cache()
def get_cache_service() -> CacheService:
    """Process-wide CacheService singleton (memoized via lru_cache)."""
    return CacheService()
|
backend/app/services/clone_service.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Voice Cloning Service (Coqui XTTS)
|
| 3 |
+
High-quality multi-lingual text-to-speech with voice cloning capabilities.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import logging
|
| 8 |
+
import torch
|
| 9 |
+
import gc
|
| 10 |
+
from typing import List, Optional, Dict, Any
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import tempfile
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
class CloneService:
    """
    Service for Voice Cloning using Coqui XTTS v2.

    The multi-gigabyte model is lazy-loaded on first use and can be
    unloaded to reclaim VRAM. Use :func:`get_clone_service` to obtain the
    process-wide singleton.
    """

    def __init__(self):
        # Prefer GPU when available; XTTS on CPU works but is very slow.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts = None  # TTS.api.TTS instance once loaded
        self.model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
        self.loaded = False

    def load_model(self):
        """Lazy load the heavy XTTS model (idempotent).

        Raises:
            ImportError: If the 'TTS' package is not installed.
        """
        if self.loaded:
            return

        try:
            logger.info(f"Loading XTTS model ({self.device})... This may take a while.")
            from TTS.api import TTS

            # Load model
            self.tts = TTS(self.model_name).to(self.device)
            self.loaded = True
            logger.info("✅ XTTS Model loaded successfully")

        except ImportError as e:
            logger.error("TTS library not installed. Please install 'TTS'.")
            # Chain the original error so the install failure stays visible.
            raise ImportError("Voice Cloning requires 'TTS' library.") from e
        except Exception as e:
            logger.error(f"Failed to load XTTS model: {e}")
            raise  # bare re-raise preserves the original traceback

    def unload_model(self):
        """Unload model to free VRAM (no-op when nothing is loaded)."""
        if self.tts:
            del self.tts
            self.tts = None
            self.loaded = False
            gc.collect()
            if torch.cuda.is_available():
                # Only meaningful (and safe) when CUDA is present.
                torch.cuda.empty_cache()
            logger.info("🗑️ XTTS Model unloaded")

    def clone_voice(
        self,
        text: str,
        speaker_wav_paths: List[str],
        language: str = "en",
        output_path: Optional[str] = None
    ) -> str:
        """
        Synthesize speech in the style of the reference audio.

        Args:
            text: Text to speak.
            speaker_wav_paths: Reference WAV file(s); multiple files can
                improve cloning quality.
            language: Target language code (see get_supported_languages()).
            output_path: Destination WAV; a temp file is created when omitted.

        Returns:
            Path to the generated WAV file.
        """
        if not self.loaded:
            self.load_model()

        if not output_path:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                output_path = f.name

        try:
            # XTTS synthesis
            # Note: speaker_wav can be a list of files for better cloning
            self.tts.tts_to_file(
                text=text,
                speaker_wav=speaker_wav_paths,
                language=language,
                file_path=output_path,
                split_sentences=True
            )

            logger.info(f"Cloned speech generated: {output_path}")
            return output_path

        except Exception as e:
            logger.error(f"Cloning failed: {e}")
            raise  # bare re-raise preserves the original traceback

    def get_supported_languages(self) -> List[str]:
        """Language codes accepted by XTTS v2."""
        return ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "hu", "ko"]


# Singleton
_clone_service: Optional[CloneService] = None


def get_clone_service() -> CloneService:
    """Get or create the process-wide CloneService singleton."""
    global _clone_service
    if _clone_service is None:
        _clone_service = CloneService()
    return _clone_service
|
backend/app/services/diarization_service.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Speaker Diarization Service - Clean Implementation
|
| 3 |
+
Uses faster-whisper + pyannote.audio directly (no whisperx)
|
| 4 |
+
|
| 5 |
+
This avoids the KeyError bugs in whisperx alignment while providing
|
| 6 |
+
the same functionality.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import gc
|
| 11 |
+
import logging
|
| 12 |
+
import torch
|
| 13 |
+
from typing import Optional, Dict, Any, List
|
| 14 |
+
from dotenv import load_dotenv
|
| 15 |
+
|
| 16 |
+
from app.core.config import get_settings
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
# Load environment variables from .env file
|
| 21 |
+
load_dotenv()
|
| 22 |
+
|
| 23 |
+
# Workaround for PyTorch 2.6+ weights_only security restriction
|
| 24 |
+
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DiarizationService:
|
| 28 |
+
"""
|
| 29 |
+
Speaker Diarization Service using faster-whisper + pyannote.audio.
|
| 30 |
+
|
| 31 |
+
This implementation avoids whisperx entirely to prevent alignment bugs.
|
| 32 |
+
|
| 33 |
+
Flow:
|
| 34 |
+
1. Transcribe with faster-whisper (word-level timestamps)
|
| 35 |
+
2. Diarize with pyannote.audio (speaker segments)
|
| 36 |
+
3. Merge speakers with transcript segments
|
| 37 |
+
|
| 38 |
+
Requires:
|
| 39 |
+
- faster-whisper (already installed)
|
| 40 |
+
- pyannote.audio
|
| 41 |
+
- Valid Hugging Face Token (HF_TOKEN) in .env
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
    def __init__(self):
        # App-wide settings singleton (paths, model config, etc.).
        self.settings = get_settings()

        # Auto-detect GPU (prefer CUDA for speed)
        if torch.cuda.is_available():
            self.device = "cuda"
            self.compute_type = "float16"
            logger.info(f"🚀 Diarization using GPU: {torch.cuda.get_device_name(0)}")
        else:
            self.device = "cpu"
            self.compute_type = "int8"
            logger.info("⚠️ Diarization using CPU (slower)")

        # Load HF token
        # pyannote's gated diarization models require a Hugging Face token;
        # without it, processing fails later in check_requirements().
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            logger.warning("⚠️ HF_TOKEN not found. Speaker diarization will fail.")

        # FFmpeg Setup for Windows
        self._setup_ffmpeg()
|
| 64 |
+
|
| 65 |
+
    def _setup_ffmpeg(self):
        """Auto-configure FFmpeg from imageio-ffmpeg if not in PATH"""
        # NOTE(review): this copies the bundled binary as "ffmpeg.exe" into the
        # current working directory — it assumes a Windows host and that CWD is
        # the backend dir; confirm behavior on Linux/containers.
        try:
            import imageio_ffmpeg
            import shutil

            ffmpeg_src = imageio_ffmpeg.get_ffmpeg_exe()
            backend_dir = os.getcwd()
            ffmpeg_dest = os.path.join(backend_dir, "ffmpeg.exe")

            # Copy once; skip if a previous run already placed the binary.
            if not os.path.exists(ffmpeg_dest):
                shutil.copy(ffmpeg_src, ffmpeg_dest)
                logger.info(f"🔧 Configured FFmpeg: {ffmpeg_dest}")

            # Prepend so downstream subprocess lookups find our copy first.
            if backend_dir not in os.environ.get("PATH", ""):
                os.environ["PATH"] = backend_dir + os.pathsep + os.environ.get("PATH", "")

        except Exception as e:
            # Best-effort: a missing FFmpeg only degrades, it must not crash startup.
            logger.warning(f"⚠️ Could not auto-configure FFmpeg: {e}")
|
| 84 |
+
|
| 85 |
+
def check_requirements(self):
|
| 86 |
+
"""Validate requirements before processing"""
|
| 87 |
+
if not self.hf_token:
|
| 88 |
+
raise ValueError(
|
| 89 |
+
"HF_TOKEN is missing. Add HF_TOKEN=your_token to .env file. "
|
| 90 |
+
"Get one at: https://huggingface.co/settings/tokens"
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
    def _get_diarization_pipeline(self):
        """Load pyannote diarization pipeline with PyTorch 2.6+ fix.

        PyTorch >= 2.6 defaults ``torch.load`` to ``weights_only=True``,
        which breaks pyannote checkpoint loading. ``torch.load`` is
        temporarily monkey-patched to force ``weights_only=False`` and the
        original is always restored in ``finally``.

        NOTE(review): ``weights_only=False`` permits arbitrary pickle
        execution — acceptable only because the checkpoint comes from the
        trusted ``pyannote`` Hugging Face repository.
        """
        from pyannote.audio import Pipeline

        # Monkey-patch torch.load for PyTorch 2.6+ compatibility
        original_load = torch.load
        def safe_load(*args, **kwargs):
            # Drop any caller-supplied flag so our override cannot conflict.
            kwargs.pop('weights_only', None)
            return original_load(*args, **kwargs, weights_only=False)

        torch.load = safe_load
        try:
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=self.hf_token
            )
            if self.device == "cuda":
                pipeline.to(torch.device("cuda"))
            return pipeline
        finally:
            # Always restore the patched function, even if loading raised.
            torch.load = original_load
|
| 114 |
+
|
| 115 |
+
def _transcribe_with_timestamps(self, audio_path: str, language: Optional[str] = None) -> Dict:
|
| 116 |
+
"""Transcribe audio using faster-whisper with word timestamps"""
|
| 117 |
+
from faster_whisper import WhisperModel
|
| 118 |
+
|
| 119 |
+
# CTranslate2 (faster-whisper) doesn't support float16 on all GPUs
|
| 120 |
+
# Use int8 for whisper, but pyannote still benefits from CUDA
|
| 121 |
+
whisper_compute = "int8" if self.device == "cuda" else "int8"
|
| 122 |
+
model = WhisperModel(
|
| 123 |
+
"small",
|
| 124 |
+
device=self.device,
|
| 125 |
+
compute_type=whisper_compute
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
segments_raw, info = model.transcribe(
|
| 129 |
+
audio_path,
|
| 130 |
+
language=language,
|
| 131 |
+
word_timestamps=True,
|
| 132 |
+
vad_filter=True
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
segments = []
|
| 136 |
+
for segment in segments_raw:
|
| 137 |
+
segments.append({
|
| 138 |
+
"start": segment.start,
|
| 139 |
+
"end": segment.end,
|
| 140 |
+
"text": segment.text.strip(),
|
| 141 |
+
"words": [
|
| 142 |
+
{"start": w.start, "end": w.end, "word": w.word}
|
| 143 |
+
for w in (segment.words or [])
|
| 144 |
+
]
|
| 145 |
+
})
|
| 146 |
+
|
| 147 |
+
# Cleanup
|
| 148 |
+
del model
|
| 149 |
+
gc.collect()
|
| 150 |
+
|
| 151 |
+
return {
|
| 152 |
+
"segments": segments,
|
| 153 |
+
"language": info.language
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
def _preprocess_audio(self, audio_path: str) -> str:
|
| 157 |
+
"""
|
| 158 |
+
Apply noise reduction to audio file.
|
| 159 |
+
Returns path to cleaned audio file.
|
| 160 |
+
"""
|
| 161 |
+
try:
|
| 162 |
+
import noisereduce as nr
|
| 163 |
+
import librosa
|
| 164 |
+
import soundfile as sf
|
| 165 |
+
import tempfile
|
| 166 |
+
|
| 167 |
+
logger.info("🔧 Preprocessing audio (noise reduction)...")
|
| 168 |
+
|
| 169 |
+
# Load audio
|
| 170 |
+
audio, sr = librosa.load(audio_path, sr=16000, mono=True)
|
| 171 |
+
|
| 172 |
+
# Apply spectral gating noise reduction
|
| 173 |
+
reduced_noise = nr.reduce_noise(
|
| 174 |
+
y=audio,
|
| 175 |
+
sr=sr,
|
| 176 |
+
stationary=True,
|
| 177 |
+
prop_decrease=0.75
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Save to temp file
|
| 181 |
+
temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
|
| 182 |
+
sf.write(temp_file.name, reduced_noise, sr)
|
| 183 |
+
|
| 184 |
+
logger.info(f" → Noise reduction complete, saved to {temp_file.name}")
|
| 185 |
+
return temp_file.name
|
| 186 |
+
|
| 187 |
+
except ImportError as e:
|
| 188 |
+
logger.warning(f"⚠️ Audio preprocessing unavailable (install noisereduce, librosa, soundfile): {e}")
|
| 189 |
+
return audio_path
|
| 190 |
+
except Exception as e:
|
| 191 |
+
logger.warning(f"⚠️ Audio preprocessing failed: {e}")
|
| 192 |
+
return audio_path
|
| 193 |
+
|
| 194 |
+
def _merge_speakers(self, transcript: Dict, diarization) -> List[Dict]:
|
| 195 |
+
"""
|
| 196 |
+
Merge speaker labels from diarization with transcript segments.
|
| 197 |
+
|
| 198 |
+
Uses midpoint matching with nearest-speaker fallback to minimize UNKNOWN labels.
|
| 199 |
+
"""
|
| 200 |
+
segments = transcript["segments"]
|
| 201 |
+
result = []
|
| 202 |
+
|
| 203 |
+
# Build list of speaker turns for efficient lookup
|
| 204 |
+
speaker_turns = [
|
| 205 |
+
(turn.start, turn.end, spk)
|
| 206 |
+
for turn, _, spk in diarization.itertracks(yield_label=True)
|
| 207 |
+
]
|
| 208 |
+
|
| 209 |
+
for seg in segments:
|
| 210 |
+
mid_time = (seg["start"] + seg["end"]) / 2
|
| 211 |
+
speaker = None
|
| 212 |
+
|
| 213 |
+
# Step 1: Try exact midpoint match
|
| 214 |
+
for start, end, spk in speaker_turns:
|
| 215 |
+
if start <= mid_time <= end:
|
| 216 |
+
speaker = spk
|
| 217 |
+
break
|
| 218 |
+
|
| 219 |
+
# Step 2: If no match, find nearest speaker (fallback)
|
| 220 |
+
if speaker is None and speaker_turns:
|
| 221 |
+
min_distance = float('inf')
|
| 222 |
+
for start, end, spk in speaker_turns:
|
| 223 |
+
# Distance to nearest edge of speaker segment
|
| 224 |
+
if mid_time < start:
|
| 225 |
+
dist = start - mid_time
|
| 226 |
+
elif mid_time > end:
|
| 227 |
+
dist = mid_time - end
|
| 228 |
+
else:
|
| 229 |
+
dist = 0 # Should have been caught above
|
| 230 |
+
|
| 231 |
+
if dist < min_distance:
|
| 232 |
+
min_distance = dist
|
| 233 |
+
speaker = spk
|
| 234 |
+
|
| 235 |
+
# Final fallback (shouldn't happen)
|
| 236 |
+
if speaker is None:
|
| 237 |
+
speaker = "UNKNOWN"
|
| 238 |
+
|
| 239 |
+
result.append({
|
| 240 |
+
"start": seg["start"],
|
| 241 |
+
"end": seg["end"],
|
| 242 |
+
"text": seg["text"],
|
| 243 |
+
"speaker": speaker
|
| 244 |
+
})
|
| 245 |
+
|
| 246 |
+
return result
|
| 247 |
+
|
| 248 |
+
def process_audio(
    self,
    audio_path: str,
    num_speakers: Optional[int] = None,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
    language: Optional[str] = None,
    preprocess: bool = False,
) -> Dict[str, Any]:
    """
    Full diarization pipeline: [Preprocess] → Transcribe → Diarize → Merge

    Args:
        audio_path: Path to audio file
        num_speakers: Exact number of speakers (optional)
        min_speakers: Minimum speakers (optional)
        max_speakers: Maximum speakers (optional)
        language: Force language code (optional, auto-detected if None)
        preprocess: Apply noise reduction before processing (default: False)

    Returns:
        Dict with segments, speaker_stats (seconds of speech per speaker),
        language, status

    Raises:
        Re-raises any transcription/diarization error after logging the
        full traceback.
    """
    self.check_requirements()

    logger.info(f"🎤 Starting diarization on {self.device}...")

    # Optional preprocessing for noise reduction; falls back to the
    # original path when disabled.
    processed_path = self._preprocess_audio(audio_path) if preprocess else audio_path

    try:
        # Step 1: Transcribe with faster-whisper
        logger.info("Step 1/3: Transcribing audio...")
        transcript = self._transcribe_with_timestamps(processed_path, language)
        detected_lang = transcript["language"]
        logger.info(f" → Language: {detected_lang}, Segments: {len(transcript['segments'])}")

        # Step 2: Diarize with pyannote
        logger.info("Step 2/3: Identifying speakers...")
        pipeline = self._get_diarization_pipeline()

        diarization = pipeline(
            processed_path,
            num_speakers=num_speakers,
            min_speakers=min_speakers,
            max_speakers=max_speakers
        )

        # Release the pipeline eagerly — it holds large model weights.
        del pipeline
        gc.collect()

        # Step 3: Merge results
        logger.info("Step 3/3: Merging speakers with transcript...")
        segments = self._merge_speakers(transcript, diarization)

        # Total speaking time (seconds) accumulated per speaker label.
        speaker_stats: Dict[str, float] = {}
        for seg in segments:
            spk = seg["speaker"]
            speaker_stats[spk] = speaker_stats.get(spk, 0) + (seg["end"] - seg["start"])

        logger.info(f"✅ Diarization complete: {len(segments)} segments, {len(speaker_stats)} speakers")

        return {
            "segments": segments,
            "speaker_stats": speaker_stats,
            "language": detected_lang,
            "status": "success"
        }

    except Exception:
        logger.exception("Diarization failed")
        # Bare `raise` preserves the original exception and traceback;
        # `raise e` would re-raise from this frame instead.
        raise
    finally:
        # Free transient memory on both success and failure paths.
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# Singleton
_diarization_service = None


def get_diarization_service():
    """Return the process-wide DiarizationService, creating it on first use.

    NOTE(review): not guarded by a lock — concurrent first calls could
    construct two instances; confirm this is acceptable for the app's
    startup model.
    """
    global _diarization_service
    # Explicit `is None` check instead of truthiness: the intent is
    # "create exactly once", not "recreate if falsy".
    if _diarization_service is None:
        _diarization_service = DiarizationService()
    return _diarization_service
|
backend/app/services/edge_tts_service.py
ADDED
|
@@ -0,0 +1,357 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Edge-TTS Text-to-Speech Service
|
| 3 |
+
Free, high-quality neural TTS using Microsoft Edge's speech synthesis
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import asyncio
|
| 7 |
+
import io
|
| 8 |
+
import logging
|
| 9 |
+
import edge_tts
|
| 10 |
+
from typing import Optional, List, Dict, Any
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Available voice samples by language
# Static fallback catalog: EdgeTTSService.get_voices() flattens this dict
# when the live voice list cannot be fetched from the Edge TTS service.
# Keys are BCP-47 locale codes; each entry names a neural voice along with
# its gender and an informal style tag (for display purposes only).
VOICE_CATALOG = {
    "en-US": [
        {"name": "en-US-AriaNeural", "gender": "Female", "style": "professional"},
        {"name": "en-US-GuyNeural", "gender": "Male", "style": "casual"},
        {"name": "en-US-JennyNeural", "gender": "Female", "style": "friendly"},
        {"name": "en-US-ChristopherNeural", "gender": "Male", "style": "newscast"},
    ],
    "en-GB": [
        {"name": "en-GB-SoniaNeural", "gender": "Female", "style": "professional"},
        {"name": "en-GB-RyanNeural", "gender": "Male", "style": "casual"},
    ],
    "en-IN": [
        {"name": "en-IN-NeerjaNeural", "gender": "Female", "style": "professional"},
        {"name": "en-IN-PrabhatNeural", "gender": "Male", "style": "casual"},
    ],
    "hi-IN": [
        {"name": "hi-IN-SwaraNeural", "gender": "Female", "style": "professional"},
        {"name": "hi-IN-MadhurNeural", "gender": "Male", "style": "casual"},
    ],
    "es-ES": [
        {"name": "es-ES-ElviraNeural", "gender": "Female", "style": "professional"},
        {"name": "es-ES-AlvaroNeural", "gender": "Male", "style": "casual"},
    ],
    "es-MX": [
        {"name": "es-MX-DaliaNeural", "gender": "Female", "style": "professional"},
        {"name": "es-MX-JorgeNeural", "gender": "Male", "style": "casual"},
    ],
    "fr-FR": [
        {"name": "fr-FR-DeniseNeural", "gender": "Female", "style": "professional"},
        {"name": "fr-FR-HenriNeural", "gender": "Male", "style": "casual"},
    ],
    "de-DE": [
        {"name": "de-DE-KatjaNeural", "gender": "Female", "style": "professional"},
        {"name": "de-DE-ConradNeural", "gender": "Male", "style": "casual"},
    ],
    "ja-JP": [
        {"name": "ja-JP-NanamiNeural", "gender": "Female", "style": "professional"},
        {"name": "ja-JP-KeitaNeural", "gender": "Male", "style": "casual"},
    ],
    "ko-KR": [
        {"name": "ko-KR-SunHiNeural", "gender": "Female", "style": "professional"},
        {"name": "ko-KR-InJoonNeural", "gender": "Male", "style": "casual"},
    ],
    "zh-CN": [
        {"name": "zh-CN-XiaoxiaoNeural", "gender": "Female", "style": "professional"},
        {"name": "zh-CN-YunxiNeural", "gender": "Male", "style": "casual"},
    ],
}
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class EdgeTTSService:
    """
    Text-to-Speech service using Microsoft Edge TTS (free, neural voices).

    All synthesis methods produce MP3 audio bytes. The live voice list is
    cached at class level; the static VOICE_CATALOG is used as a fallback
    when the Edge service cannot be reached.
    """

    def __init__(self):
        """Initialize the Edge TTS service"""
        # Kept for backward compatibility; the shared class-level
        # _voices_cache below is what get_voices() actually uses.
        self._all_voices = None

    # Class-level cache of the formatted voice list (shared by instances).
    _voices_cache = None

    async def get_voices(self, language: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get available voices, optionally filtered by language prefix
        (e.g. "en" matches "en-US", "en-GB", ...).
        """
        # Populate the shared cache on first use.
        if EdgeTTSService._voices_cache is None:
            try:
                raw_voices = await edge_tts.list_voices()
                # Transform the Edge payload to our flat format.
                EdgeTTSService._voices_cache = [
                    {
                        "name": v["ShortName"],
                        "display_name": v["ShortName"].replace("-", " ").split("Neural")[0].strip(),
                        "language_code": v["Locale"],
                        "gender": v["Gender"],
                        "voice_type": "Neural",
                    }
                    for v in raw_voices
                ]
            except Exception as e:
                logger.error(f"Failed to fetch voices from Edge TTS: {e}. Falling back to catalog.")
                # Offline fallback: flatten the static catalog.
                EdgeTTSService._voices_cache = [
                    {
                        "name": v["name"],
                        "display_name": v["name"].replace("-", " ").replace("Neural", "").strip(),
                        "language_code": lang,
                        "gender": v["gender"],
                        "voice_type": "Neural",
                    }
                    for lang, lang_voices in VOICE_CATALOG.items()
                    for v in lang_voices
                ]

        voices = EdgeTTSService._voices_cache

        # Filter by language prefix if specified.
        if language:
            voices = [v for v in voices if v["language_code"].startswith(language)]

        return voices

    def get_voices_sync(self, language: Optional[str] = None) -> List[Dict[str, Any]]:
        """Synchronous wrapper for get_voices().

        Safe both from plain sync code and from a thread that already runs
        an asyncio loop. The previous implementation scheduled the
        coroutine on the *current* running loop and then blocked on
        future.result() from the same thread — a guaranteed deadlock.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # No loop running in this thread: run the coroutine directly.
            return asyncio.run(self.get_voices(language))

        # A loop is running here; blocking it would deadlock. Execute the
        # coroutine on a fresh loop in a worker thread instead.
        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, self.get_voices(language)).result()

    def build_ssml(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "medium",
        pitch: str = "medium",
        emphasis: str = None,
        breaks: bool = True
    ) -> str:
        """
        Build SSML markup for advanced prosody control.

        Args:
            text: Plain text to convert
            voice: Voice name
            rate: Speed - 'x-slow', 'slow', 'medium', 'fast', 'x-fast' or percentage
            pitch: Pitch - 'x-low', 'low', 'medium', 'high', 'x-high' or Hz offset
            emphasis: Optional emphasis level - 'reduced', 'moderate', 'strong'
            breaks: Auto-insert breaks at punctuation

        Returns:
            SSML-formatted string

        NOTE(review): `text` is not XML-escaped, so '<' or '&' in the input
        produces invalid SSML — confirm callers sanitize input.
        """
        import re

        # rate/pitch accept either a named level or a raw percentage/Hz
        # offset; both forms are emitted verbatim into the prosody tag.
        # (The previous version had a tautological `x if x in [...] else x`
        # conditional here — removed as dead code.)
        ssml_parts = ['<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="en-US">']
        ssml_parts.append(f'<voice name="{voice}">')
        ssml_parts.append(f'<prosody rate="{rate}" pitch="{pitch}">')

        if emphasis:
            ssml_parts.append(f'<emphasis level="{emphasis}">')

        # Auto-insert breaks for natural speech: short pause after clause
        # punctuation, longer pause after sentence-ending punctuation.
        if breaks:
            processed_text = re.sub(r'([,;:])\s*', r'\1<break time="200ms"/>', text)
            processed_text = re.sub(r'([.!?])\s+', r'\1<break time="500ms"/>', processed_text)
            ssml_parts.append(processed_text)
        else:
            ssml_parts.append(text)

        if emphasis:
            ssml_parts.append('</emphasis>')

        ssml_parts.append('</prosody>')
        ssml_parts.append('</voice>')
        ssml_parts.append('</speak>')

        return ''.join(ssml_parts)

    async def synthesize_ssml(
        self,
        ssml_text: str,
        voice: str = "en-US-AriaNeural",
    ) -> bytes:
        """
        Synthesize speech from SSML markup.

        Args:
            ssml_text: SSML-formatted text
            voice: Voice name (for edge-tts communication)

        Returns:
            Audio bytes (MP3)

        NOTE(review): the SSML string is handed to edge_tts.Communicate
        as-is — confirm the installed edge-tts version honors SSML rather
        than reading the markup aloud.
        """
        logger.info(f"Synthesizing SSML with voice: {voice}")

        communicate = edge_tts.Communicate(ssml_text, voice)

        audio_buffer = io.BytesIO()
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                audio_buffer.write(chunk["data"])

        audio_buffer.seek(0)
        return audio_buffer.read()

    async def synthesize_stream(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "+0%",
        pitch: str = "+0Hz",
    ):
        """
        Stream speech synthesis chunks.

        Optimized to stream sentence-by-sentence to reduce TTFB (Time To
        First Byte), avoiding full-text buffering issues.

        Yields:
            MP3 audio byte chunks.
        """
        import re

        # Split text into sentences (keeping terminal punctuation) so each
        # sentence is synthesized — and can be played — incrementally.
        sentences = re.findall(r'[^.!?]+(?:[.!?]+|$)', text)
        if not sentences:
            sentences = [text]

        logger.info(f"Streaming {len(sentences)} sentences for low latency...")

        for sentence in sentences:
            if not sentence.strip():
                continue

            communicate = edge_tts.Communicate(sentence, voice, rate=rate, pitch=pitch)

            async for chunk in communicate.stream():
                if chunk["type"] == "audio":
                    yield chunk["data"]

    async def synthesize(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "+0%",
        pitch: str = "+0Hz",
    ) -> bytes:
        """
        Synthesize speech from text

        Args:
            text: Text to synthesize
            voice: Voice name (e.g., 'en-US-AriaNeural')
            rate: Speaking rate adjustment (e.g., '+20%', '-10%')
            pitch: Pitch adjustment (e.g., '+5Hz', '-10Hz')

        Returns:
            Audio content as bytes (MP3 format)
        """
        # Reuse the streaming method to avoid duplicated synthesis logic.
        audio_buffer = io.BytesIO()
        async for chunk in self.synthesize_stream(text, voice, rate, pitch):
            audio_buffer.write(chunk)

        audio_buffer.seek(0)
        return audio_buffer.read()

    def synthesize_sync(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "+0%",
        pitch: str = "+0Hz",
    ) -> bytes:
        """Synchronous wrapper for synthesize().

        Uses the same loop-safe pattern as get_voices_sync(): run directly
        when no loop is active, otherwise delegate to a worker thread.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return asyncio.run(self.synthesize(text, voice, rate, pitch))

        import concurrent.futures
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(asyncio.run, self.synthesize(text, voice, rate, pitch)).result()

    async def synthesize_to_response(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        speaking_rate: float = 1.0,
        pitch: float = 0.0,
    ) -> Dict[str, Any]:
        """
        Synthesize speech and return API-compatible response

        Args:
            text: Text to synthesize
            voice: Voice name
            speaking_rate: Rate multiplier (1.0 = normal, 1.5 = 50% faster)
            pitch: Pitch adjustment in semitones (-20 to +20)

        Returns:
            Dictionary with base64 audio content and metadata
        """
        import base64
        import time

        start_time = time.time()

        # Convert multiplier/semitone inputs to Edge TTS's string format.
        rate_percent = int((speaking_rate - 1.0) * 100)
        rate_str = f"+{rate_percent}%" if rate_percent >= 0 else f"{rate_percent}%"
        pitch_str = f"+{int(pitch)}Hz" if pitch >= 0 else f"{int(pitch)}Hz"

        audio_bytes = await self.synthesize(text, voice, rate_str, pitch_str)

        processing_time = time.time() - start_time

        # Rough estimate (~150 chars/second at normal speed); guard against
        # a zero speaking_rate to avoid ZeroDivisionError.
        estimated_duration = len(text) / 150 / speaking_rate if speaking_rate else 0.0

        return {
            "audio_content": base64.b64encode(audio_bytes).decode("utf-8"),
            "encoding": "MP3",
            "audio_size": len(audio_bytes),
            "duration_estimate": estimated_duration,
            "voice_used": voice,
            "processing_time": processing_time,
            "cached": False,
        }
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# Singleton instance
# Lazily constructed on first call to get_edge_tts_service(); construction
# is cheap (no network I/O happens until a synthesis/voice-list call).
_edge_tts_service: Optional[EdgeTTSService] = None


def get_edge_tts_service() -> EdgeTTSService:
    """Get or create the EdgeTTSService singleton"""
    global _edge_tts_service
    # NOTE(review): not lock-guarded — concurrent first calls could build
    # two instances; harmless here since the service holds no state beyond
    # the class-level voice cache.
    if _edge_tts_service is None:
        _edge_tts_service = EdgeTTSService()
    return _edge_tts_service
|
backend/app/services/emotion_service.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Emotion Analysis Service
|
| 3 |
+
Detects emotion from audio using Wav2Vec2 and text using NLP
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from typing import Dict, List, Any, Optional
|
| 12 |
+
|
| 13 |
+
from app.core.config import get_settings
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EmotionService:
    """
    Service for Speech Emotion Recognition (SER).

    Uses 'ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition', a
    Wav2Vec2 sequence classifier over 8 emotion classes. The model is
    lazy-loaded on first use to keep idle RAM low.
    """

    def __init__(self):
        self.model_name = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
        self._model = None       # loaded lazily by _load_model()
        self._processor = None   # Wav2Vec2Processor, loaded with the model
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Class labels in the model's output order: logits index i maps
        # to self.emotions[i].
        self.emotions = [
            "angry", "calm", "disgust", "fearful",
            "happy", "neutral", "sad", "surprised"
        ]

    def _load_model(self):
        """Lazy load model to save RAM (no-op if already loaded)."""
        if self._model is None:
            try:
                from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification

                logger.info(f"🎭 Loading Emotion Model ({self.device})...")
                self._processor = Wav2Vec2Processor.from_pretrained(self.model_name)
                self._model = Wav2Vec2ForSequenceClassification.from_pretrained(self.model_name)
                self._model.to(self.device)
                logger.info("✅ Emotion Model loaded")
            except Exception as e:
                logger.error(f"Failed to load emotion model: {e}")
                raise

    def _classify(self, inputs) -> Dict[str, float]:
        """Run the loaded model on processor output and return the
        per-emotion probability distribution (softmax over logits)."""
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            logits = self._model(**inputs).logits

        probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
        return {emotion: float(p) for emotion, p in zip(self.emotions, probs)}

    def analyze_audio(self, audio_path: str) -> Dict[str, Any]:
        """
        Analyze emotion of an entire audio file.

        Args:
            audio_path: Path to audio file

        Returns:
            Dict with dominant emotion, its confidence, and the full
            probability distribution

        Raises:
            Re-raises loading/inference errors after logging them.
        """
        import librosa

        self._load_model()

        try:
            # Wav2Vec2 expects 16 kHz mono. Only the first 60 s are
            # analyzed to bound memory; longer files would need chunking.
            y, sr = librosa.load(audio_path, sr=16000, duration=60)

            inputs = self._processor(y, sampling_rate=16000, return_tensors="pt", padding=True)
            scores = self._classify(inputs)
            dominant = max(scores, key=scores.get)

            return {
                "dominant_emotion": dominant,
                "confidence": scores[dominant],
                "distribution": scores
            }

        except Exception as e:
            logger.error(f"Audio emotion analysis failed: {e}")
            # Bare `raise` keeps the original traceback intact
            # (the previous `raise e` re-raised from this frame).
            raise

    def analyze_audio_segment(self, audio_data: np.ndarray, sr: int = 16000) -> Dict[str, Any]:
        """
        Analyze a raw numpy audio segment.

        Best-effort: returns a neutral result instead of raising, so
        per-segment failures don't abort a larger pipeline.
        """
        self._load_model()

        try:
            inputs = self._processor(audio_data, sampling_rate=sr, return_tensors="pt", padding=True)
            scores = self._classify(inputs)
            dominant = max(scores, key=scores.get)

            return {
                "emotion": dominant,
                "score": scores[dominant]
            }
        except Exception as e:
            logger.error(f"Segment analysis failed: {e}")
            return {"emotion": "neutral", "score": 0.0}
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# Singleton
# Lazily constructed on first call; creation is cheap because the model
# itself is only loaded on first analysis (see EmotionService._load_model).
_emotion_service = None

def get_emotion_service() -> EmotionService:
    global _emotion_service
    # NOTE(review): not lock-guarded — concurrent first calls could build
    # two instances; verify single-threaded startup if that matters.
    if _emotion_service is None:
        _emotion_service = EmotionService()
    return _emotion_service
|