lordofgaming committed on
Commit
673435a
·
1 Parent(s): fbdfd83

Initial VoiceForge deployment (clean)

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. Dockerfile +39 -0
  2. README.md +20 -6
  3. backend/.env +16 -0
  4. backend/.flake8 +4 -0
  5. backend/Dockerfile +50 -0
  6. backend/app/__init__.py +3 -0
  7. backend/app/api/__init__.py +3 -0
  8. backend/app/api/routes/__init__.py +35 -0
  9. backend/app/api/routes/analysis.py +60 -0
  10. backend/app/api/routes/audio.py +100 -0
  11. backend/app/api/routes/auth.py +116 -0
  12. backend/app/api/routes/batch.py +204 -0
  13. backend/app/api/routes/cloning.py +81 -0
  14. backend/app/api/routes/health.py +93 -0
  15. backend/app/api/routes/s2s.py +45 -0
  16. backend/app/api/routes/sign.py +164 -0
  17. backend/app/api/routes/sign_bridge.py +63 -0
  18. backend/app/api/routes/stt.py +489 -0
  19. backend/app/api/routes/transcripts.py +200 -0
  20. backend/app/api/routes/translation.py +261 -0
  21. backend/app/api/routes/tts.py +245 -0
  22. backend/app/api/routes/ws.py +208 -0
  23. backend/app/core/__init__.py +7 -0
  24. backend/app/core/config.py +108 -0
  25. backend/app/core/limiter.py +27 -0
  26. backend/app/core/middleware.py +70 -0
  27. backend/app/core/request_size_middleware.py +91 -0
  28. backend/app/core/security.py +113 -0
  29. backend/app/core/security_encryption.py +107 -0
  30. backend/app/core/security_headers.py +37 -0
  31. backend/app/core/ws_security.py +181 -0
  32. backend/app/main.py +273 -0
  33. backend/app/models/__init__.py +19 -0
  34. backend/app/models/audio_file.py +47 -0
  35. backend/app/models/auth.py +36 -0
  36. backend/app/models/base.py +44 -0
  37. backend/app/models/sign_lstm.py +63 -0
  38. backend/app/models/transcript.py +67 -0
  39. backend/app/schemas/__init__.py +39 -0
  40. backend/app/schemas/stt.py +98 -0
  41. backend/app/schemas/transcript.py +69 -0
  42. backend/app/schemas/tts.py +67 -0
  43. backend/app/services/__init__.py +13 -0
  44. backend/app/services/audio_service.py +101 -0
  45. backend/app/services/batch_service.py +348 -0
  46. backend/app/services/cache_service.py +71 -0
  47. backend/app/services/clone_service.py +104 -0
  48. backend/app/services/diarization_service.py +338 -0
  49. backend/app/services/edge_tts_service.py +357 -0
  50. backend/app/services/emotion_service.py +132 -0
Dockerfile ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CPU-only base for free-tier HuggingFace Spaces
2
+ FROM python:3.10-slim
3
+
4
+ # Set environment variables
5
+ ENV PYTHONUNBUFFERED=1 \
6
+ PYTHONDONTWRITEBYTECODE=1 \
7
+ DEBIAN_FRONTEND=noninteractive
8
+
9
+ # Install system dependencies
10
+ RUN apt-get update && apt-get install -y \
11
+ python3-pip \
12
+ python3-dev \
13
+ ffmpeg \
14
+ libsndfile1 \
15
+ git \
16
+ supervisor \
17
+ && rm -rf /var/lib/apt/lists/*
18
+
19
+ # Set working directory
20
+ WORKDIR /app
21
+
22
+ # Copy requirements and install
23
+ COPY deploy/huggingface/requirements.txt .
24
+ RUN pip3 install --no-cache-dir -r requirements.txt
25
+
26
+ # Copy the rest of the application
27
+ COPY . .
28
+
29
+ # Create directory for weights and logs
30
+ RUN mkdir -p backend/app/models/weights /var/log/supervisor
31
+
32
+ # Copy supervisor config
33
+ COPY deploy/huggingface/supervisord.conf /etc/supervisor/conf.d/supervisord.conf
34
+
35
+ # Expose ports (HF expects the app on 7860)
36
+ EXPOSE 7860
37
+
38
+ # Run supervisor to start both backend and frontend
39
+ CMD ["/usr/bin/supervisord", "-c", "/etc/supervisor/conf.d/supervisord.conf"]
README.md CHANGED
@@ -1,10 +1,24 @@
1
  ---
2
- title: Voiceforge
3
- emoji: 🏢
4
- colorFrom: red
5
- colorTo: indigo
6
  sdk: docker
7
- pinned: false
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: VoiceForge Universal
3
+ emoji: 🎙️
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: docker
7
+ pinned: true
8
+ license: mit
9
  ---
10
 
11
+ # VoiceForge: Universal Communication Platform
12
+
13
+ **Instant Speech-to-Speech | Signed Communication | Meeting Intelligence**
14
+
15
+ Powered by:
16
+ - **Whisper** (STT)
17
+ - **SeamlessM4T** (S2S)
18
+ - **MediaPipe** (Sign Recognition)
19
+ - **Edge TTS** (Synthesis)
20
+
21
+ ## 🚀 Usage
22
+ 1. Click the "Speech-to-Speech" tab to translate voice instantly.
23
+ 2. Use "Signed Communication" to visualize ASL.
24
+ 3. Upload meetings for instant minutes.
backend/.env ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # VoiceForge Backend Environment Variables
2
+
3
+ # Database
4
+ DATABASE_URL=sqlite:///./voiceforge.db
5
+
6
+ # Hugging Face Token (for Speaker Diarization)
7
+ # Get your token at: https://huggingface.co/settings/tokens
8
+ # See docs/HF_TOKEN_ROTATION.md for setup instructions
9
+ HF_TOKEN=<REDACTED — SECURITY: a live Hugging Face access token was committed here; revoke/rotate it at https://huggingface.co/settings/tokens and scrub it from git history, since redacting this view does not remove it from the commit>
10
+
11
+ # Encryption Key (auto-generated in dev, REQUIRED in production)
12
+ # Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
13
+ # ENCRYPTION_KEY=
14
+
15
+ # Environment (development | production)
16
+ ENVIRONMENT=development
backend/.flake8 ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ [flake8]
2
+ max-line-length = 120
3
+ extend-ignore = E203
4
+ exclude = .git,__pycache__,docs/source/conf.py,old,build,dist,venv
backend/Dockerfile ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Build Stage
2
+ FROM python:3.10-slim as builder
3
+
4
+ WORKDIR /app
5
+
6
+ # Set environment variables
7
+ ENV PYTHONDONTWRITEBYTECODE 1
8
+ ENV PYTHONUNBUFFERED 1
9
+
10
+ # Install system dependencies required for building python packages
11
+ # ffmpeg is needed for audio processing
12
+ RUN apt-get update && apt-get install -y --no-install-recommends \
13
+ gcc \
14
+ ffmpeg \
15
+ && rm -rf /var/lib/apt/lists/*
16
+
17
+ # Install python dependencies
18
+ COPY requirements.txt .
19
+ RUN pip wheel --no-cache-dir --no-deps --wheel-dir /app/wheels -r requirements.txt
20
+
21
+
22
+ # Final Stage
23
+ FROM python:3.10-slim
24
+
25
+ WORKDIR /app
26
+
27
+ # Install runtime dependencies (ffmpeg)
28
+ RUN apt-get update && apt-get install -y --no-install-recommends \
29
+ ffmpeg \
30
+ && rm -rf /var/lib/apt/lists/*
31
+
32
+ # Copy wheels from builder
33
+ COPY --from=builder /app/wheels /wheels
34
+ COPY --from=builder /app/requirements.txt .
35
+
36
+ # Install dependencies from wheels
37
+ RUN pip install --no-cache /wheels/*
38
+
39
+ # Copy application code
40
+ COPY . .
41
+
42
+ # Create a non-root user
43
+ RUN addgroup --system app && adduser --system --group app
44
+ USER app
45
+
46
+ # Expose port
47
+ EXPOSE 8000
48
+
49
+ # Run commands
50
+ CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
backend/app/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ VoiceForge Backend Package
3
+ """
backend/app/api/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ VoiceForge API Package
3
+ """
backend/app/api/routes/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceForge API Routes Package
3
+ """
4
+
5
+ from .stt import router as stt_router
6
+ from .tts import router as tts_router
7
+ from .health import router as health_router
8
+ from .transcripts import router as transcripts_router
9
+ from .ws import router as ws_router
10
+ from .translation import router as translation_router
11
+ from .batch import router as batch_router
12
+ from .analysis import router as analysis_router
13
+ from .audio import router as audio_router
14
+ from .cloning import router as cloning_router
15
+ from .sign import router as sign_router
16
+ from .auth import router as auth_router
17
+ from .s2s import router as s2s_router
18
+ from .sign_bridge import router as sign_bridge_router
19
+
20
+ __all__ = [
21
+ "stt_router",
22
+ "tts_router",
23
+ "health_router",
24
+ "transcripts_router",
25
+ "ws_router",
26
+ "translation_router",
27
+ "batch_router",
28
+ "analysis_router",
29
+ "audio_router",
30
+ "cloning_router",
31
+ "sign_router",
32
+ "auth_router",
33
+ "s2s_router",
34
+ "sign_bridge_router",
35
+ ]
backend/app/api/routes/analysis.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analysis API Routes
3
+ Endpoints for Emotion and Sentiment Analysis
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
7
+ from typing import Dict, Any
8
+ import logging
9
+ import os
10
+ import shutil
11
+ import tempfile
12
+
13
+ from app.services.emotion_service import get_emotion_service
14
+ from app.services.nlp_service import get_nlp_service
15
+
16
+ logger = logging.getLogger(__name__)
17
+ router = APIRouter(prefix="/analysis", tags=["Analysis"])
18
+
19
+
20
+ @router.post("/emotion/audio")
21
+ async def analyze_audio_emotion(
22
+ file: UploadFile = File(..., description="Audio file to analyze"),
23
+ ):
24
+ """
25
+ Analyze emotions in an audio file using Wav2Vec2.
26
+ Returns dominant emotion and probability distribution.
27
+ """
28
+ service = get_emotion_service()
29
+
30
+ # Save to temp file
31
+ suffix = os.path.splitext(file.filename)[1] or ".wav"
32
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
33
+ shutil.copyfileobj(file.file, tmp)
34
+ tmp_path = tmp.name
35
+
36
+ try:
37
+ result = service.analyze_audio(tmp_path)
38
+ return result
39
+ except Exception as e:
40
+ logger.error(f"Emotion analysis failed: {e}")
41
+ raise HTTPException(status_code=500, detail=str(e))
42
+ finally:
43
+ try:
44
+ os.unlink(tmp_path)
45
+ except:
46
+ pass
47
+
48
+
49
+ @router.post("/sentiment/text")
50
+ async def analyze_text_sentiment(
51
+ text: str = Form(..., description="Text to analyze"),
52
+ ):
53
+ """
54
+ Analyze text sentiment (polarity and subjectivity).
55
+ """
56
+ nlp_service = get_nlp_service()
57
+ try:
58
+ return nlp_service.analyze_sentiment(text)
59
+ except Exception as e:
60
+ raise HTTPException(status_code=500, detail=str(e))
backend/app/api/routes/audio.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio Editing API Routes
3
+ """
4
+
5
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
6
+ from fastapi.responses import FileResponse
7
+ from typing import List, Optional
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+ import uuid
12
+
13
+ from app.services.audio_service import get_audio_service, AudioService
14
+
15
+ router = APIRouter(prefix="/audio", tags=["Audio Studio"])
16
+
17
+ @router.post("/trim")
18
+ async def trim_audio(
19
+ file: UploadFile = File(..., description="Audio file"),
20
+ start_sec: float = Form(..., description="Start time in seconds"),
21
+ end_sec: float = Form(..., description="End time in seconds"),
22
+ service: AudioService = Depends(get_audio_service)
23
+ ):
24
+ """Trim an audio file"""
25
+ suffix = os.path.splitext(file.filename)[1] or ".mp3"
26
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
27
+ shutil.copyfileobj(file.file, tmp)
28
+ tmp_path = tmp.name
29
+
30
+ try:
31
+ output_path = tmp_path.replace(suffix, f"_trimmed{suffix}")
32
+ service.trim_audio(tmp_path, int(start_sec * 1000), int(end_sec * 1000), output_path)
33
+
34
+ return FileResponse(
35
+ output_path,
36
+ filename=f"trimmed_{file.filename}",
37
+ background=None # Let FastAPI handle cleanup? No, we need custom cleanup or use background task
38
+ )
39
+ except Exception as e:
40
+ raise HTTPException(status_code=500, detail=str(e))
41
+ # Note: Temp files might persist. In prod, use a cleanup task.
42
+
43
+ @router.post("/merge")
44
+ async def merge_audio(
45
+ files: List[UploadFile] = File(..., description="Files to merge"),
46
+ format: str = Form("mp3", description="Output format"),
47
+ service: AudioService = Depends(get_audio_service)
48
+ ):
49
+ """Merge multiple audio files"""
50
+ temp_files = []
51
+ try:
52
+ for file in files:
53
+ suffix = os.path.splitext(file.filename)[1] or ".mp3"
54
+ tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
55
+ content = await file.read()
56
+ tmp.write(content)
57
+ tmp.close()
58
+ temp_files.append(tmp.name)
59
+
60
+ output_filename = f"merged_{uuid.uuid4()}.{format}"
61
+ output_path = os.path.join(tempfile.gettempdir(), output_filename)
62
+
63
+ service.merge_audio(temp_files, output_path)
64
+
65
+ return FileResponse(output_path, filename=output_filename)
66
+
67
+ except Exception as e:
68
+ raise HTTPException(status_code=500, detail=str(e))
69
+ finally:
70
+ for p in temp_files:
71
+ try:
72
+ os.unlink(p)
73
+ except:
74
+ pass
75
+
76
+ @router.post("/convert")
77
+ async def convert_audio(
78
+ file: UploadFile = File(..., description="Audio file"),
79
+ target_format: str = Form(..., description="Target format (mp3, wav, flac, ogg)"),
80
+ service: AudioService = Depends(get_audio_service)
81
+ ):
82
+ """Convert audio format"""
83
+ suffix = os.path.splitext(file.filename)[1] or ".wav"
84
+ with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
85
+ shutil.copyfileobj(file.file, tmp)
86
+ tmp_path = tmp.name
87
+
88
+ try:
89
+ output_path = service.convert_format(tmp_path, target_format)
90
+ return FileResponse(
91
+ output_path,
92
+ filename=f"{os.path.splitext(file.filename)[0]}.{target_format}"
93
+ )
94
+ except Exception as e:
95
+ raise HTTPException(status_code=500, detail=str(e))
96
+ finally:
97
+ try:
98
+ os.unlink(tmp_path)
99
+ except:
100
+ pass
backend/app/api/routes/auth.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ from typing import List
3
+ from pydantic import BaseModel
4
+ import secrets
5
+ from fastapi import APIRouter, Depends, HTTPException, status
6
+ from fastapi.security import OAuth2PasswordRequestForm
7
+ from sqlalchemy.orm import Session
8
+
9
+ from ...core.security import (
10
+ create_access_token,
11
+ get_password_hash,
12
+ verify_password,
13
+ get_current_active_user,
14
+ ACCESS_TOKEN_EXPIRE_MINUTES
15
+ )
16
+ from ...models import get_db, User, ApiKey
17
+ from ...core.limiter import limiter
18
+ from fastapi import APIRouter, Depends, HTTPException, status, Request
19
+
20
+ router = APIRouter(prefix="/auth", tags=["Authentication"])
21
+
22
+ # --- Schemas ---
23
+ class Token(BaseModel):
24
+ access_token: str
25
+ token_type: str
26
+
27
+ class UserCreate(BaseModel):
28
+ email: str
29
+ password: str
30
+ full_name: str = None
31
+
32
+ class UserOut(BaseModel):
33
+ id: int
34
+ email: str
35
+ full_name: str = None
36
+ is_active: bool
37
+
38
+ class Config:
39
+ orm_mode = True
40
+
41
+ class ApiKeyCreate(BaseModel):
42
+ name: str
43
+
44
+ class ApiKeyOut(BaseModel):
45
+ key: str
46
+ name: str
47
+ created_at: datetime
48
+
49
+ class Config:
50
+ orm_mode = True
51
+
52
+
53
+ # --- Endpoints ---
54
+
55
+ @router.post("/register", response_model=UserOut)
56
+ @limiter.limit("5/minute")
57
+ async def register(request: Request, user_in: UserCreate, db: Session = Depends(get_db)):
58
+ """Register a new user"""
59
+ existing_user = db.query(User).filter(User.email == user_in.email).first()
60
+ if existing_user:
61
+ raise HTTPException(status_code=400, detail="Email already registered")
62
+
63
+ hashed_password = get_password_hash(user_in.password)
64
+ new_user = User(
65
+ email=user_in.email,
66
+ hashed_password=hashed_password,
67
+ full_name=user_in.full_name
68
+ )
69
+ db.add(new_user)
70
+ db.commit()
71
+ db.refresh(new_user)
72
+ return new_user
73
+
74
+ @router.post("/login", response_model=Token)
75
+ @limiter.limit("5/minute")
76
+ async def login(request: Request, form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)):
77
+ """Login to get access token"""
78
+ user = db.query(User).filter(User.email == form_data.username).first()
79
+ if not user or not verify_password(form_data.password, user.hashed_password):
80
+ raise HTTPException(
81
+ status_code=status.HTTP_401_UNAUTHORIZED,
82
+ detail="Incorrect email or password",
83
+ headers={"WWW-Authenticate": "Bearer"},
84
+ )
85
+
86
+ access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
87
+ access_token = create_access_token(
88
+ subject=user.id, expires_delta=access_token_expires
89
+ )
90
+ return {"access_token": access_token, "token_type": "bearer"}
91
+
92
+ @router.post("/api-keys", response_model=ApiKeyOut)
93
+ async def create_api_key(
94
+ key_in: ApiKeyCreate,
95
+ current_user: User = Depends(get_current_active_user),
96
+ db: Session = Depends(get_db)
97
+ ):
98
+ """Generate a new API key for the current user"""
99
+ # Generate secure 32-char key
100
+ raw_key = secrets.token_urlsafe(32)
101
+ api_key_str = f"vf_{raw_key}" # Prefix for identification
102
+
103
+ new_key = ApiKey(
104
+ key=api_key_str,
105
+ name=key_in.name,
106
+ user_id=current_user.id
107
+ )
108
+ db.add(new_key)
109
+ db.commit()
110
+ db.refresh(new_key)
111
+ return new_key
112
+
113
+ @router.get("/me", response_model=UserOut)
114
+ async def read_users_me(current_user: User = Depends(get_current_active_user)):
115
+ """Get current user details"""
116
+ return current_user
backend/app/api/routes/batch.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch Processing API Routes
3
+ Endpoints for submitting and managing batch transcription jobs
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends, BackgroundTasks
7
+ from fastapi.responses import FileResponse
8
+ from pydantic import BaseModel, Field
9
+ from typing import List, Optional, Dict, Any
10
+ import logging
11
+ import shutil
12
+ import os
13
+ import tempfile
14
+ from pathlib import Path
15
+
16
+ from app.services.batch_service import get_batch_service
17
+
18
+ logger = logging.getLogger(__name__)
19
+ router = APIRouter(prefix="/batch", tags=["batch"])
20
+
21
+
22
+ # Request/Response Models
23
+ class BatchJobResponse(BaseModel):
24
+ """Batch job response model."""
25
+ job_id: str
26
+ status: str
27
+ progress: float
28
+ created_at: str
29
+ total_files: int
30
+ completed_files: int
31
+ failed_files: int
32
+ has_zip: bool
33
+ files: Optional[Dict[str, Any]] = None
34
+
35
+
36
+ # Endpoints
37
+ @router.post("/transcribe", response_model=BatchJobResponse)
38
+ async def create_batch_job(
39
+ background_tasks: BackgroundTasks,
40
+ files: List[UploadFile] = File(..., description="Audio files to transcribe"),
41
+ language: Optional[str] = Form(None, description="Language code (e.g., 'en', 'hi')"),
42
+ output_format: str = Form("txt", description="Output format (txt, srt)"),
43
+ ):
44
+ """
45
+ Submit a batch of audio files for transcription.
46
+
47
+ 1. Uploads multiple files
48
+ 2. Creates a batch job
49
+ 3. Starts processing in background
50
+
51
+ Args:
52
+ files: List of audio files
53
+ language: Optional language code
54
+ output_format: Output format (txt or srt)
55
+
56
+ Returns:
57
+ Created job details
58
+ """
59
+ if not files:
60
+ raise HTTPException(status_code=400, detail="No files provided")
61
+
62
+ if len(files) > 50:
63
+ raise HTTPException(status_code=400, detail="Maximum 50 files per batch")
64
+
65
+ try:
66
+ service = get_batch_service()
67
+
68
+ # Create temp files for processing
69
+ file_paths = {}
70
+ original_names = []
71
+
72
+ for file in files:
73
+ suffix = Path(file.filename).suffix or ".wav"
74
+ # Create a named temp file that persists until manually deleted
75
+ tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
76
+ content = await file.read()
77
+ tmp.write(content)
78
+ tmp.close()
79
+
80
+ file_paths[file.filename] = tmp.name
81
+ original_names.append(file.filename)
82
+
83
+ # Create job
84
+ job = service.create_job(
85
+ filenames=original_names,
86
+ options={
87
+ "language": language,
88
+ "output_format": output_format,
89
+ }
90
+ )
91
+
92
+ # Connect to Celery worker for processing
93
+ from app.workers.tasks import process_audio_file
94
+
95
+ # NOTE: For MVP batch service, we are currently keeping the simplified background_tasks approach
96
+ # because the 'process_audio_file' task defined in tasks.py is for individual files,
97
+ # whereas 'process_job' handles the whole batch logic (zipping etc).
98
+ # To fully migrate, we would need to refactor batch_service to span multiple tasks.
99
+ #
100
+ # For now, let's keep the background_task for the orchestrator, and have the orchestrator
101
+ # call the celery tasks for individual files?
102
+ # Actually, `service.process_job` currently runs synchronously in a background thread.
103
+ # We will leave as is for 3.1 step 1, but we CAN use Celery for the individual transcriptions.
104
+
105
+ # Start processing in background (Orchestrator runs in thread, calls expensive operations)
106
+ background_tasks.add_task(
107
+ service.process_job,
108
+ job_id=job.job_id,
109
+ file_paths=file_paths,
110
+ )
111
+
112
+ return job.to_dict()
113
+
114
+ except Exception as e:
115
+ # Cleanup any created temp files on error
116
+ for path in file_paths.values():
117
+ try:
118
+ os.unlink(path)
119
+ except:
120
+ pass
121
+ logger.error(f"Batch job creation failed: {e}")
122
+ raise HTTPException(status_code=500, detail=str(e))
123
+
124
+
125
+ @router.get("/jobs", response_model=List[BatchJobResponse])
126
+ async def list_jobs(limit: int = 10):
127
+ """
128
+ List recent batch jobs.
129
+
130
+ Args:
131
+ limit: Max number of jobs to return
132
+
133
+ Returns:
134
+ List of jobs
135
+ """
136
+ service = get_batch_service()
137
+ jobs = service.list_jobs(limit)
138
+ return [job.to_dict() for job in jobs]
139
+
140
+
141
+ @router.get("/{job_id}", response_model=BatchJobResponse)
142
+ async def get_job_status(job_id: str):
143
+ """
144
+ Get status of a specific batch job.
145
+
146
+ Args:
147
+ job_id: Job ID
148
+
149
+ Returns:
150
+ Job details and progress
151
+ """
152
+ service = get_batch_service()
153
+ job = service.get_job(job_id)
154
+
155
+ if not job:
156
+ raise HTTPException(status_code=404, detail="Job not found")
157
+
158
+ return job.to_dict()
159
+
160
+
161
+ @router.get("/{job_id}/download")
162
+ async def download_results(job_id: str):
163
+ """
164
+ Download batch job results as ZIP.
165
+
166
+ Args:
167
+ job_id: Job ID
168
+
169
+ Returns:
170
+ ZIP file download
171
+ """
172
+ service = get_batch_service()
173
+ zip_path = service.get_zip_path(job_id)
174
+
175
+ if not zip_path:
176
+ raise HTTPException(status_code=404, detail="Results not available (job may be processing or failed)")
177
+
178
+ return FileResponse(
179
+ path=zip_path,
180
+ filename=f"batch_{job_id}_results.zip",
181
+ media_type="application/zip",
182
+ )
183
+
184
+
185
+ @router.delete("/{job_id}")
186
+ async def delete_job(job_id: str):
187
+ """
188
+ Delete a batch job and cleanup files.
189
+
190
+ Args:
191
+ job_id: Job ID
192
+ """
193
+ service = get_batch_service()
194
+
195
+ # Try to cancel first if running
196
+ service.cancel_job(job_id)
197
+
198
+ # Delete data
199
+ success = service.delete_job(job_id)
200
+
201
+ if not success:
202
+ raise HTTPException(status_code=404, detail="Job not found")
203
+
204
+ return {"status": "deleted", "job_id": job_id}
backend/app/api/routes/cloning.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Voice Cloning API Routes
3
+ """
4
+
5
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form, Depends
6
+ from fastapi.responses import FileResponse
7
+ from typing import List, Optional
8
+ import os
9
+ import shutil
10
+ import tempfile
11
+ import uuid
12
+
13
+ from app.services.clone_service import get_clone_service, CloneService
14
+
15
+ router = APIRouter(prefix="/clone", tags=["Voice Cloning"])
16
+
17
+ @router.post("/synthesize")
18
+ async def clone_synthesize(
19
+ text: str = Form(..., description="Text to speak"),
20
+ language: str = Form("en", description="Language code (en, es, fr, de, etc.)"),
21
+ files: List[UploadFile] = File(..., description="Reference audio samples (1-3 files, 3-10s each recommended)"),
22
+ service: CloneService = Depends(get_clone_service)
23
+ ):
24
+ """
25
+ Clone a voice from reference audio samples.
26
+
27
+ Uses Coqui XTTS v2.
28
+ WARNING: Heavy operation. May take 5-20 seconds depending on GPU.
29
+ """
30
+
31
+ # Validation
32
+ if not files:
33
+ raise HTTPException(status_code=400, detail="At least one reference audio file is required")
34
+
35
+ temp_files = []
36
+
37
+ try:
38
+ # Save reference files
39
+ for file in files:
40
+ suffix = os.path.splitext(file.filename)[1] or ".wav"
41
+ tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
42
+ content = await file.read()
43
+ tmp.write(content)
44
+ tmp.close()
45
+ temp_files.append(tmp.name)
46
+
47
+ # Generate output path
48
+ output_filename = f"cloned_{uuid.uuid4()}.wav"
49
+ output_path = os.path.join(tempfile.gettempdir(), output_filename)
50
+
51
+ # Synthesize
52
+ service.clone_voice(
53
+ text=text,
54
+ speaker_wav_paths=temp_files,
55
+ language=language,
56
+ output_path=output_path
57
+ )
58
+
59
+ return FileResponse(
60
+ output_path,
61
+ filename="cloned_speech.wav",
62
+ media_type="audio/wav"
63
+ )
64
+
65
+ except ImportError:
66
+ raise HTTPException(status_code=503, detail="Voice Cloning service not available (TTS library missing)")
67
+ except Exception as e:
68
+ raise HTTPException(status_code=500, detail=str(e))
69
+
70
+ finally:
71
+ # Cleanup input files
72
+ for p in temp_files:
73
+ try:
74
+ os.unlink(p)
75
+ except:
76
+ pass
77
+ # Note: Output file cleanup needs management in prod (background task or stream)
78
+
79
+ @router.get("/languages")
80
+ def get_languages(service: CloneService = Depends(get_clone_service)):
81
+ return {"languages": service.get_supported_languages()}
backend/app/api/routes/health.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Health Check Router
3
+ """
4
+
5
+ from fastapi import APIRouter
6
+
7
+ router = APIRouter(prefix="/health", tags=["Health"])
8
+
9
+
10
+ @router.get("")
11
+ @router.get("/")
12
+ async def health_check():
13
+ """Basic health check endpoint"""
14
+ return {
15
+ "status": "healthy",
16
+ "service": "voiceforge-api",
17
+ "version": "1.0.0",
18
+ }
19
+
20
+
21
+ @router.get("/ready")
22
+ async def readiness_check():
23
+ """Readiness check - verifies all dependencies are available"""
24
+ # TODO: Check database, Redis, Google Cloud connectivity
25
+ return {
26
+ "status": "ready",
27
+ "checks": {
28
+ "database": "ok",
29
+ "redis": "ok",
30
+ "google_cloud": "ok",
31
+ }
32
+ }
33
+
34
+
35
+ @router.get("/memory")
36
+ async def memory_status():
37
+ """Get current memory usage and loaded models"""
38
+ from ...services.whisper_stt_service import (
39
+ _whisper_models,
40
+ _model_last_used,
41
+ get_memory_usage_mb
42
+ )
43
+ import time
44
+
45
+ current_time = time.time()
46
+ models_info = {}
47
+
48
+ for name in _whisper_models.keys():
49
+ last_used = _model_last_used.get(name, 0)
50
+ idle_seconds = current_time - last_used if last_used else 0
51
+ models_info[name] = {
52
+ "loaded": True,
53
+ "idle_seconds": round(idle_seconds, 1)
54
+ }
55
+
56
+ return {
57
+ "memory_mb": round(get_memory_usage_mb(), 1),
58
+ "loaded_models": list(_whisper_models.keys()),
59
+ "models_detail": models_info
60
+ }
61
+
62
+
63
+ @router.post("/memory/cleanup")
64
+ async def cleanup_memory():
65
+ """Unload idle models to free memory"""
66
+ from ...services.whisper_stt_service import cleanup_idle_models, get_memory_usage_mb
67
+
68
+ before = get_memory_usage_mb()
69
+ cleanup_idle_models()
70
+ after = get_memory_usage_mb()
71
+
72
+ return {
73
+ "memory_before_mb": round(before, 1),
74
+ "memory_after_mb": round(after, 1),
75
+ "freed_mb": round(before - after, 1)
76
+ }
77
+
78
+
79
+ @router.post("/memory/unload-all")
80
+ async def unload_all():
81
+ """Unload ALL models to free maximum memory"""
82
+ from ...services.whisper_stt_service import unload_all_models, get_memory_usage_mb
83
+
84
+ before = get_memory_usage_mb()
85
+ unloaded = unload_all_models()
86
+ after = get_memory_usage_mb()
87
+
88
+ return {
89
+ "unloaded_models": unloaded,
90
+ "memory_before_mb": round(before, 1),
91
+ "memory_after_mb": round(after, 1),
92
+ "freed_mb": round(before - after, 1)
93
+ }
backend/app/api/routes/s2s.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speech-to-Speech API Router
3
+ """
4
+
5
+ import logging
6
+ from typing import Optional
7
+ from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException
8
+
9
+ from app.services.speech_bridge_service import get_bridge_service, SpeechBridgeService
10
+ from app.core.config import get_settings
11
+
12
+ logger = logging.getLogger(__name__)
13
+ router = APIRouter(prefix="/s2s", tags=["Speech-to-Speech"])
14
+ settings = get_settings()
15
+
16
+ @router.post("/process")
17
+ async def process_speech_to_speech(
18
+ file: UploadFile = File(..., description="Audio file to process"),
19
+ source_lang: str = Form("en", description="Source language code (e.g. 'en', 'hi')"),
20
+ target_lang: str = Form("es", description="Target language code (e.g. 'es', 'fr')"),
21
+ voice_id: Optional[str] = Form(None, description="Target TTS Voice ID"),
22
+ bridge_service: SpeechBridgeService = Depends(get_bridge_service)
23
+ ):
24
+ """
25
+ Process audio: Speech -> Text -> Translation -> Speech
26
+ """
27
+ try:
28
+ # Read audio file
29
+ audio_bytes = await file.read()
30
+
31
+ result = await bridge_service.process_speech_to_speech(
32
+ audio_bytes=audio_bytes,
33
+ source_lang=source_lang,
34
+ target_lang=target_lang,
35
+ voice_id=voice_id
36
+ )
37
+
38
+ if "error" in result:
39
+ raise HTTPException(status_code=400, detail=result["error"])
40
+
41
+ return result
42
+
43
+ except Exception as e:
44
+ logger.error(f"S2S API Error: {e}")
45
+ raise HTTPException(status_code=500, detail=str(e))
backend/app/api/routes/sign.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sign Language API Routes
3
+ Provides WebSocket and REST endpoints for ASL recognition.
4
+ """
5
+
6
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, UploadFile, File, HTTPException
7
+ from fastapi.responses import JSONResponse
8
+ import numpy as np
9
+ import base64
10
+ import cv2
11
+ import logging
12
+ from typing import List
13
+
14
+ from ...services.sign_recognition_service import get_sign_service, SignPrediction
15
+ from ...services.sign_avatar_service import get_avatar_service
16
+ from pydantic import BaseModel
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ router = APIRouter(prefix="/sign", tags=["Sign Language"])
21
+
22
class TextToSignRequest(BaseModel):
    """Request body for /sign/animate: plain text to finger-spell."""
    text: str  # text to convert into a sign-animation (gloss) sequence
24
+
25
+
26
@router.get("/health")
async def sign_health():
    """Check if sign recognition service is available"""
    try:
        # Resolving the service is the availability check itself.
        get_sign_service()
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    return {"status": "ready", "service": "SignRecognitionService"}
34
+
35
+
36
@router.post("/recognize")
async def recognize_sign(file: UploadFile = File(..., description="Image of hand sign")):
    """
    Recognize ASL letter from a single image.

    Upload an image containing a hand sign to get the predicted letter.

    Raises:
        HTTPException 400: the upload could not be decoded as an image.
        HTTPException 500: unexpected recognition failure.
    """
    try:
        # Read and decode the upload into a BGR image
        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        # cv2.imdecode returns None (it does not raise) on undecodable input
        if image is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        # Get predictions
        service = get_sign_service()
        predictions = service.process_frame(image)

        if not predictions:
            return JSONResponse({
                "success": True,
                "predictions": [],
                "message": "No hands detected in image"
            })

        return JSONResponse({
            "success": True,
            "predictions": [
                {
                    "letter": p.letter,
                    "confidence": p.confidence
                }
                for p in predictions
            ]
        })

    except HTTPException:
        # Bug fix: re-raise client errors; previously the 400 above was
        # converted into a 500 by the generic handler below.
        raise
    except Exception as e:
        logger.error(f"Sign recognition error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
77
+
78
+
79
@router.websocket("/live")
async def sign_websocket(websocket: WebSocket):
    """
    WebSocket endpoint for real-time sign language recognition.

    Client sends base64-encoded JPEG frames, server responds with predictions.

    Protocol:
    - Client sends: {"frame": "<base64 jpeg>"}
    - Server sends: {"predictions": [{"letter": "A", "confidence": 0.8}]}

    Malformed messages get an {"error": ...} reply and the loop continues;
    only an unexpected server error closes the socket (code 1011).
    """
    await websocket.accept()
    # Resolved once per connection and reused for every frame.
    service = get_sign_service()

    logger.info("Sign language WebSocket connected")

    try:
        while True:
            # Receive frame from client (raises WebSocketDisconnect on close)
            data = await websocket.receive_json()

            if "frame" not in data:
                await websocket.send_json({"error": "Missing 'frame' field"})
                continue

            # Decode base64 image
            try:
                frame_data = base64.b64decode(data["frame"])
                nparr = np.frombuffer(frame_data, np.uint8)
                frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

                # cv2.imdecode returns None (no exception) for bad data
                if frame is None:
                    await websocket.send_json({"error": "Invalid frame data"})
                    continue

            except Exception as e:
                await websocket.send_json({"error": f"Frame decode error: {e}"})
                continue

            # Process frame
            # NOTE(review): process_frame looks synchronous; a slow model would
            # block the event loop here — confirm, or offload to a thread.
            predictions = service.process_frame(frame)

            # Send results
            await websocket.send_json({
                "predictions": [
                    {
                        "letter": p.letter,
                        "confidence": round(p.confidence, 2)
                    }
                    for p in predictions
                ]
            })

    except WebSocketDisconnect:
        logger.info("Sign language WebSocket disconnected")
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        await websocket.close(code=1011, reason=str(e))
137
+
138
+
139
@router.get("/alphabet")
async def get_alphabet():
    """Get list of supported ASL letters"""
    implemented = "ABCDILUVWY5"  # Currently implemented
    return {
        "supported_letters": [letter for letter in implemented],
        "note": "J and Z require motion tracking (coming soon)",
    }
146
+
147
+
148
@router.post("/animate")
async def animate_text(request: TextToSignRequest):
    """
    Convert text to sign language animation sequence (Finger Spelling).
    """
    try:
        glosses = get_avatar_service().text_to_glosses(request.text)
        return {
            "success": True,
            "sequence": glosses,
            "count": len(glosses),
        }
    except Exception as exc:
        logger.error(f"Animation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
backend/app/api/routes/sign_bridge.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Sign-to-Speech API Router
3
+ """
4
+
5
+ import logging
6
+ from typing import Optional, Dict, Any
7
+ from fastapi import APIRouter, UploadFile, File, Form, Depends, HTTPException, Body
8
+ from pydantic import BaseModel
9
+
10
+ from app.services.sign_bridge_service import get_sign_bridge_service, SignBridgeService
11
+ from app.core.config import get_settings
12
+
13
+ logger = logging.getLogger(__name__)
14
+ router = APIRouter(prefix="/sign-bridge", tags=["Sign-to-Speech"])
15
+ settings = get_settings()
16
+
17
class SignTextRequest(BaseModel):
    """Request body for /sign-bridge/speak."""
    text: str  # text recognised from sign language, to be spoken aloud
    voice_id: Optional[str] = "en-US-AriaNeural"  # TTS voice identifier
20
+
21
@router.post("/speak")
async def speak_sign_text(
    request: SignTextRequest,
    bridge_service: SignBridgeService = Depends(get_sign_bridge_service)
):
    """
    Speak text derived from Sign Language.

    Raises:
        HTTPException 400: the bridge service reported an error.
        HTTPException 500: unexpected synthesis failure.
    """
    try:
        result = await bridge_service.speak_text(
            text=request.text,
            voice_id=request.voice_id
        )
        if "error" in result:
            raise HTTPException(status_code=400, detail=result["error"])
        return result
    except HTTPException:
        # Bug fix: without this clause the 400 above was swallowed by the
        # generic handler below and re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"Sign-Bridge Speak Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
40
+
41
@router.post("/process-frame")
async def process_frame(
    file: UploadFile = File(...),
    bridge_service: SignBridgeService = Depends(get_sign_bridge_service)
):
    """
    Process a video frame for sign recognition (Backend-side).
    Note: For real-time, client-side MediaPipe is preferred.

    Raises:
        HTTPException 400: the upload could not be decoded as an image.
        HTTPException 500: unexpected processing failure.
    """
    try:
        # Read image
        import numpy as np
        import cv2

        contents = await file.read()
        nparr = np.frombuffer(contents, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

        # Bug fix: cv2.imdecode returns None for undecodable input; without
        # this check the bridge service received None and failed obscurely
        # (the sibling /sign/recognize route already performs this check).
        if img is None:
            raise HTTPException(status_code=400, detail="Invalid image file")

        result = await bridge_service.process_sign_frame(img)
        return result
    except HTTPException:
        # Re-raise client errors instead of converting them to 500s.
        raise
    except Exception as e:
        logger.error(f"Sign-Bridge Frame Error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
backend/app/api/routes/stt.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speech-to-Text API Router
3
+ """
4
+
5
+ import logging
6
+ from datetime import datetime
7
+ from typing import Optional, List
8
+
9
+ from fastapi import APIRouter, UploadFile, File, Form, HTTPException, Depends, Request
10
+ from fastapi.responses import JSONResponse
11
+
12
+ from ...core.limiter import limiter
13
+
14
+ from ...services.stt_service import get_stt_service, STTService
15
+ from ...services.file_service import get_file_service, FileService
16
+ from ...schemas.stt import (
17
+ TranscriptionResponse,
18
+ TranscriptionRequest,
19
+ LanguageInfo,
20
+ LanguageListResponse,
21
+ )
22
+ from ...core.config import get_settings
23
+ from sqlalchemy.orm import Session
24
+ from app.models import get_db, AudioFile, Transcript
25
+ from ...workers.tasks import process_audio_file
26
+ from celery.result import AsyncResult
27
+ from ...schemas.stt import (
28
+ TranscriptionResponse,
29
+ TranscriptionRequest,
30
+ LanguageInfo,
31
+ LanguageListResponse,
32
+ AsyncTranscriptionResponse,
33
+ TaskStatusResponse,
34
+ )
35
+
36
+
37
+ logger = logging.getLogger(__name__)
38
+ router = APIRouter(prefix="/stt", tags=["Speech-to-Text"])
39
+ settings = get_settings()
40
+
41
+
42
@router.get("/languages", response_model=LanguageListResponse)
async def get_supported_languages(
    stt_service: STTService = Depends(get_stt_service),
):
    """
    Get list of supported languages for speech-to-text
    """
    supported = stt_service.get_supported_languages()
    return LanguageListResponse(languages=supported, total=len(supported))
54
+
55
+
56
@router.post("/upload", response_model=TranscriptionResponse)
@limiter.limit("10/minute")
async def transcribe_upload(
    request: Request,
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    enable_punctuation: bool = Form(default=True, description="Enable automatic punctuation"),
    enable_word_timestamps: bool = Form(default=True, description="Include word-level timestamps"),
    enable_diarization: bool = Form(default=False, description="Enable speaker diarization"),
    speaker_count: Optional[int] = Form(default=None, description="Expected number of speakers"),
    prompt: Optional[str] = Form(None, description="Custom vocabulary/keywords (e.g. 'VoiceForge, PyTorch')"),
    stt_service: STTService = Depends(get_stt_service),
    file_service: FileService = Depends(get_file_service),
    db: Session = Depends(get_db),

):
    """
    Transcribe an uploaded audio file

    Supports: WAV, MP3, M4A, FLAC, OGG, WebM

    For files longer than 1 minute, consider using the async endpoint.

    Flow: validate format/language -> persist upload -> synchronous STT ->
    persist AudioFile + Transcript rows -> return the transcription with
    the new transcript id attached.

    Raises:
        HTTPException 400: unsupported format/language or invalid input.
        HTTPException 404: stored file not found by the STT service.
        HTTPException 500: transcription failure.
    """
    # Validate file type
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    ext = file.filename.split(".")[-1].lower()
    if ext not in settings.supported_audio_formats_list:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {ext}. Supported: {', '.join(settings.supported_audio_formats_list)}"
        )

    # Validate language
    if language not in settings.supported_languages_list:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {language}. Supported: {', '.join(settings.supported_languages_list)}"
        )

    try:
        # Read file content
        content = await file.read()

        # Save to storage
        storage_path, metadata = file_service.save_upload(
            file_content=content,
            original_filename=file.filename,
        )

        logger.info(f"Processing transcription for {file.filename} ({len(content)} bytes)")

        # Perform transcription (synchronous; blocks this request until done)
        result = stt_service.transcribe_file(
            audio_path=storage_path,
            language=language,
            enable_automatic_punctuation=enable_punctuation,
            enable_word_time_offsets=enable_word_timestamps,
            enable_speaker_diarization=enable_diarization,
            diarization_speaker_count=speaker_count,
            sample_rate=metadata.get("sample_rate"),
            prompt=prompt,  # Custom vocabulary
        )

        # Clean up temp file (optional - could keep for history)
        # file_service.delete_file(storage_path)

        # Save to database

        try:
            # 1. Create AudioFile record
            audio_file = AudioFile(
                storage_path=str(storage_path),
                original_filename=file.filename,
                duration=result.duration,
                format=ext,
                sample_rate=metadata.get("sample_rate"),
                language=language,
                detected_language=result.language,
                status="done"
            )
            db.add(audio_file)
            db.flush()  # get ID

            # 2. Create Transcript record
            transcript = Transcript(
                audio_file_id=audio_file.id,
                raw_text=result.text,
                processed_text=result.text,  # initially same
                segments=[s.model_dump() for s in result.segments] if result.segments else [],
                language=result.language,
                created_at=datetime.utcnow(),
            )
            db.add(transcript)
            db.commit()
            db.refresh(transcript)

            # Return result with ID
            response_data = result.model_dump()
            response_data["id"] = transcript.id

            # Explicitly validate to catch errors early
            try:
                return TranscriptionResponse(**response_data)
            except Exception as e:
                logger.error(f"Validation error for response: {e}")
                logger.error(f"Response data: {response_data}")
                raise HTTPException(status_code=500, detail=f"Response validation failed: {str(e)}")
            # return response - removed undefined variable

        except Exception as e:
            # NOTE(review): this also catches the HTTPException raised just
            # above, so a response-validation failure is answered with the
            # raw `result` instead of a 500 — confirm this is intended.
            logger.error(f"Failed to save to DB: {e}")
            # Don't fail the request if DB save fails, just return result
            # But in production we might want to ensure persistence
            return result

    except FileNotFoundError as e:
        logger.error(f"File error: {e}")
        raise HTTPException(status_code=404, detail=str(e))
    except ValueError as e:
        logger.error(f"Validation error: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Transcription failed: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
182
+
183
+
184
@router.post("/upload/quality")
async def transcribe_quality(
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    preprocess: bool = Form(default=False, description="Apply noise reduction (5-15% WER improvement)"),
    prompt: Optional[str] = Form(None, description="Custom vocabulary/keywords"),
):
    """
    High-quality transcription mode (optimized for accuracy).

    Features:
    - beam_size=5 for more accurate decoding (~40% fewer errors)
    - condition_on_previous_text=False to reduce hallucinations
    - Optional audio preprocessing for noisy environments

    Trade-off: ~2x slower than standard mode
    Best for: Important recordings, noisy audio, reduced error tolerance

    Raises:
        HTTPException 400: no filename provided.
        HTTPException 500: transcription failure.
    """
    from app.services.whisper_stt_service import get_whisper_stt_service
    import tempfile
    import os

    # Validate file
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    try:
        content = await file.read()

        # Spool the upload to disk: the whisper service reads from a path.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(content)
            temp_path = f.name

        try:
            stt_service = get_whisper_stt_service()
            result = stt_service.transcribe_quality(
                temp_path,
                language=language,
                preprocess=preprocess,
                prompt=prompt,
            )
            return result
        finally:
            # Best-effort cleanup; fix: OSError instead of a bare except so
            # KeyboardInterrupt/SystemExit are never swallowed here.
            try:
                os.unlink(temp_path)
            except OSError:
                pass

    except Exception as e:
        logger.exception(f"Quality transcription failed: {e}")
        raise HTTPException(status_code=500, detail=f"Transcription failed: {str(e)}")
236
+
237
+
238
@router.post("/upload/batch")
async def transcribe_batch(
    files: List[UploadFile] = File(..., description="Multiple audio files to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    batch_size: int = Form(default=8, description="Batch size (8 optimal for CPU)"),
):
    """
    Batch transcription for high throughput.

    Uses BatchedInferencePipeline for 2-3x speedup on concurrent files.

    Best for: Processing multiple files, API with high concurrency

    Per-file failures are reported inline in `results` rather than failing
    the whole batch.

    Raises:
        HTTPException 400: no files provided.
    """
    from app.services.whisper_stt_service import get_whisper_stt_service
    import tempfile
    import os

    if not files:
        raise HTTPException(status_code=400, detail="No files provided")

    results = []
    stt_service = get_whisper_stt_service()

    for file in files:
        # Skip nameless multipart sections rather than failing the batch.
        if not file.filename:
            continue

        try:
            content = await file.read()

            # Spool each upload to disk: the whisper service reads from a path.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                f.write(content)
                temp_path = f.name

            try:
                result = stt_service.transcribe_batched(
                    temp_path,
                    language=language,
                    batch_size=batch_size,
                )
                result["filename"] = file.filename
                results.append(result)
            finally:
                # Best-effort cleanup; fix: OSError instead of a bare except
                # so KeyboardInterrupt/SystemExit are never swallowed here.
                try:
                    os.unlink(temp_path)
                except OSError:
                    pass

        except Exception as e:
            logger.error(f"Failed to transcribe {file.filename}: {e}")
            results.append({
                "filename": file.filename,
                "error": str(e),
            })

    return {
        "count": len(results),
        "results": results,
        "mode": "batched",
        "batch_size": batch_size,
    }
299
+
300
+
301
@router.post("/async-upload", response_model=AsyncTranscriptionResponse)
async def transcribe_async_upload(
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: str = Form(default="en-US", description="Language code"),
    file_service: FileService = Depends(get_file_service),
    db: Session = Depends(get_db),
):
    """
    Asynchronously transcribe an uploaded audio file (Celery)

    Persists the upload, creates a 'queued' AudioFile row, then dispatches
    the `process_audio_file` Celery task. Poll /stt/tasks/{task_id} for
    progress.

    Raises:
        HTTPException 400: missing filename or unsupported format.
        HTTPException 500: storage/DB/queueing failure.
    """
    # Validate file type
    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    ext = file.filename.split(".")[-1].lower()
    if ext not in settings.supported_audio_formats_list:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported format: {ext}"
        )

    try:
        content = await file.read()
        storage_path, metadata = file_service.save_upload(
            file_content=content,
            original_filename=file.filename,
        )

        # Create AudioFile record with 'queued' status
        audio_file = AudioFile(
            storage_path=str(storage_path),
            original_filename=file.filename,
            duration=0.0,  # Will be updated by worker
            format=ext,
            sample_rate=metadata.get("sample_rate"),
            language=language,
            status="queued"
        )
        db.add(audio_file)
        # Commit before dispatching so the worker can see the row.
        db.commit()
        db.refresh(audio_file)

        # Trigger Celery Task
        task = process_audio_file.delay(audio_file.id)

        return AsyncTranscriptionResponse(
            task_id=task.id,
            audio_file_id=audio_file.id,
            status="queued",
            message="File uploaded and queued for processing"
        )

    except Exception as e:
        logger.exception(f"Async upload failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
356
+
357
+
358
@router.get("/tasks/{task_id}", response_model=TaskStatusResponse)
async def get_task_status(task_id: str, db: Session = Depends(get_db)):
    """
    Check status of an async transcription task
    """
    celery_result = AsyncResult(task_id)

    status_report = TaskStatusResponse(
        task_id=task_id,
        status=celery_result.status.lower(),
        created_at=datetime.utcnow(),  # Approximate or fetch from DB tracked tasks
        updated_at=datetime.utcnow()
    )

    if celery_result.successful():
        # The worker persists its output to the DB and returns None, so we
        # only report completion here; linking would require storing the
        # task_id on AudioFile/Transcript.
        status_report.status = "completed"
        status_report.progress = 100.0
    elif celery_result.failed():
        status_report.status = "failed"
        status_report.error = str(celery_result.result)
    elif celery_result.state == 'PROGRESS':
        # Progress payloads aren't emitted by the task yet; just flag it.
        status_report.status = "processing"

    return status_report
388
+
389
+
390
@router.post("/transcribe-bytes", response_model=TranscriptionResponse)
async def transcribe_bytes(
    audio_content: bytes,
    language: str = "en-US",
    encoding: str = "LINEAR16",
    sample_rate: int = 16000,
    stt_service: STTService = Depends(get_stt_service),
):
    """
    Transcribe raw audio bytes (for streaming/real-time use)

    This endpoint is primarily for internal use or advanced clients
    that send pre-processed audio data.
    """
    try:
        return stt_service.transcribe_bytes(
            audio_content=audio_content,
            language=language,
            encoding=encoding,
            sample_rate=sample_rate,
        )
    except Exception as exc:
        logger.exception(f"Transcription failed: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
415
+
416
+
417
+ # TODO: WebSocket endpoint for real-time streaming
418
+ # @router.websocket("/stream")
419
+ # async def stream_transcription(websocket: WebSocket):
420
+ # """Real-time streaming transcription via WebSocket"""
421
+ # pass
422
+
423
@router.post("/upload/diarize")
async def diarize_audio(
    file: UploadFile = File(..., description="Audio file to diarize"),
    num_speakers: Optional[int] = Form(None, description="Exact number of speakers (optional)"),
    min_speakers: Optional[int] = Form(None, description="Minimum number of speakers (optional)"),
    max_speakers: Optional[int] = Form(None, description="Maximum number of speakers (optional)"),
    language: Optional[str] = Form(None, description="Language code (e.g., 'en'). Auto-detected if not provided."),
    preprocess: bool = Form(False, description="Apply noise reduction before processing (improves accuracy for noisy audio)"),
):
    """
    Perform Speaker Diarization ("Who said what").

    Uses faster-whisper for transcription + pyannote.audio for speaker identification.

    Requires:
    - HF_TOKEN in .env for Pyannote model access

    Returns:
    - segments: List of segments with timestamps, text, and speaker labels
    - speaker_stats: Speaking time per speaker
    - language: Detected/specified language

    Raises:
        HTTPException 400: missing filename or missing HF token.
        HTTPException 503: diarization dependencies not installed.
        HTTPException 500: any other processing failure.
    """
    from app.services.diarization_service import get_diarization_service
    import tempfile
    import os

    if not file.filename:
        raise HTTPException(status_code=400, detail="No filename provided")

    try:
        # Save temp file (service reads from a path)
        content = await file.read()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            f.write(content)
            temp_path = f.name

        try:
            service = get_diarization_service()
            result = service.process_audio(
                temp_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                language=language,
                preprocess=preprocess,
            )
            return result

        except ValueError as e:
            # Token missing
            raise HTTPException(status_code=400, detail=str(e))
        except ImportError as e:
            # Not installed
            raise HTTPException(status_code=503, detail=str(e))
        except Exception as e:
            logger.exception("Diarization error")
            raise HTTPException(status_code=500, detail=f"Diarization failed: {str(e)}")

        finally:
            # Best-effort cleanup; fix: OSError instead of a bare except so
            # KeyboardInterrupt/SystemExit are never swallowed here.
            try:
                os.unlink(temp_path)
            except OSError:
                pass

    except HTTPException:
        # Bug fix: previously the 400/503 mapped above fell through to the
        # generic handler below and were re-raised as 500s.
        raise
    except Exception as e:
        logger.error(f"Diarization request failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
backend/app/api/routes/transcripts.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transcript Management Routes
3
+ CRUD operations and Export
4
+ """
5
+
6
+ from typing import List, Optional
7
+ from fastapi import APIRouter, Depends, HTTPException, Response, Query, UploadFile, File, Form
8
+ from sqlalchemy.orm import Session
9
+ from datetime import datetime
10
+
11
+ from ...models import get_db, Transcript, AudioFile
12
+ from ...schemas.transcript import TranscriptResponse, TranscriptUpdate
13
+ from ...services.nlp_service import get_nlp_service, NLPService
14
+ from ...services.export_service import ExportService
15
+
16
+
17
+ router = APIRouter(prefix="/transcripts", tags=["Transcripts"])
18
+
19
+
20
@router.get("", response_model=List[TranscriptResponse])
async def list_transcripts(
    skip: int = 0,
    limit: int = 100,
    db: Session = Depends(get_db),
):
    """List all transcripts"""
    # Newest first, with simple offset/limit pagination.
    query = db.query(Transcript).order_by(Transcript.created_at.desc())
    return query.offset(skip).limit(limit).all()
29
+
30
+
31
@router.get("/{transcript_id}", response_model=TranscriptResponse)
async def get_transcript(
    transcript_id: int,
    db: Session = Depends(get_db),
):
    """Get specific transcript details"""
    record = db.query(Transcript).filter(Transcript.id == transcript_id).first()
    if record is None:
        raise HTTPException(status_code=404, detail="Transcript not found")
    return record
41
+
42
+
43
@router.post("/{transcript_id}/analyze")
async def analyze_transcript(
    transcript_id: int,
    db: Session = Depends(get_db),
    nlp_service: NLPService = Depends(get_nlp_service),
):
    """Run NLP analysis on a transcript.

    Computes sentiment, keywords and a summary from the transcript's
    processed text, stores them on the row, and returns the analysis.

    Raises:
        HTTPException 404: no transcript with this id.
        HTTPException 400: transcript has no text to analyze.
    """
    transcript = db.query(Transcript).filter(Transcript.id == transcript_id).first()
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")

    if not transcript.processed_text:
        raise HTTPException(status_code=400, detail="Transcript has no text content")

    # Run analysis (synchronous; duration depends on text length)
    analysis = nlp_service.process_transcript(transcript.processed_text)

    # Update DB with the derived fields
    transcript.sentiment = analysis["sentiment"]
    transcript.topics = {"keywords": analysis["keywords"]}
    transcript.summary = analysis["summary"]
    transcript.updated_at = datetime.utcnow()

    db.commit()
    db.refresh(transcript)

    return {
        "status": "success",
        "analysis": analysis
    }
73
+
74
+
75
@router.get("/{transcript_id}/export")
async def export_transcript(
    transcript_id: int,
    format: str = Query(..., regex="^(txt|srt|vtt|pdf)$"),
    db: Session = Depends(get_db),
):
    """
    Export transcript to specific format
    """
    transcript = db.query(Transcript).filter(Transcript.id == transcript_id).first()
    if transcript is None:
        raise HTTPException(status_code=404, detail="Transcript not found")

    # Flatten the ORM row into the plain dict the export service expects.
    payload = {
        "id": transcript.id,
        "text": transcript.processed_text,
        "created_at": str(transcript.created_at),
        "duration": 0,
        "segments": transcript.segments,
        "words": [],
        "sentiment": transcript.sentiment,
    }

    # Dispatch table: exporter callable + response media type per format.
    exporters = {
        "txt": (ExportService.to_txt, "text/plain"),
        "srt": (ExportService.to_srt, "text/plain"),
        "vtt": (ExportService.to_vtt, "text/vtt"),
        "pdf": (ExportService.to_pdf, "application/pdf"),
    }
    if format not in exporters:
        raise HTTPException(status_code=400, detail="Unsupported format")
    exporter, media_type = exporters[format]

    return Response(
        content=exporter(payload),
        media_type=media_type,
        headers={
            "Content-Disposition": f'attachment; filename="transcript_{transcript_id}.{format}"'
        },
    )
121
@router.post("/meeting")
async def process_meeting(
    file: UploadFile = File(..., description="Audio recording of meeting"),
    num_speakers: Optional[int] = Form(None, description="Number of speakers (hint)"),
    language: Optional[str] = Form(None, description="Language code"),
    db: Session = Depends(get_db),
):
    """
    Process a meeting recording:
    1. Diarization (Who spoke when)
    2. Transcription (What was said)
    3. NLP Analysis (Summary, Action Items, Sentiment)
    4. Save to DB

    Runs synchronously in the request; long recordings can take minutes.

    Raises:
        HTTPException 500: any pipeline or persistence failure.
    """
    import shutil
    import os
    import tempfile
    from ...services.meeting_service import get_meeting_service

    # Save upload to temp file (streamed copy, so large files don't sit in memory)
    suffix = os.path.splitext(file.filename)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        shutil.copyfileobj(file.file, tmp)
        tmp_path = tmp.name

    try:
        meeting_service = get_meeting_service()

        # Run full pipeline
        # This can be slow (minutes) so strictly speaking should be a background task
        # But for this MVP level we'll do it synchronously with a long timeout
        result = meeting_service.process_meeting(
            audio_path=tmp_path,
            num_speakers=num_speakers,
            language=language
        )

        # Save to DB
        # Create AudioFile record first
        # NOTE(review): these keyword args (filename/filepath/file_size) differ
        # from the storage_path/original_filename fields used by the STT
        # routes' AudioFile rows — confirm against the AudioFile model.
        audio_file = AudioFile(
            filename=file.filename,
            filepath="processed_in_memory",  # We delete temp file, so no perm path
            duration=result["metadata"]["duration_seconds"],
            file_size=0,
            format=suffix.replace(".", "")
        )
        db.add(audio_file)
        db.commit()
        db.refresh(audio_file)

        # Create Transcript record
        transcript = Transcript(
            audio_file_id=audio_file.id,
            raw_text=result["raw_text"],
            processed_text=result["raw_text"],
            segments=result["transcript_segments"],
            sentiment=result["sentiment"],
            topics={"keywords": result["topics"]},
            action_items=result["action_items"],
            attendees=result["metadata"]["attendees"],
            summary=result["summary"],
            language=result["metadata"]["language"],
            confidence=0.95,  # Estimated
            duration=result["metadata"]["duration_seconds"],
            created_at=datetime.utcnow()
        )
        db.add(transcript)
        db.commit()
        db.refresh(transcript)

        return result

    except Exception as e:
        # NOTE(review): a broad catch — any HTTPException raised inside would
        # also be re-wrapped as a 500 here.
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Cleanup
        try:
            os.unlink(tmp_path)
        except:
            # NOTE(review): bare except also swallows KeyboardInterrupt;
            # consider narrowing to OSError.
            pass
backend/app/api/routes/translation.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Translation API Routes
3
+ Endpoints for text and audio translation services
4
+ """
5
+
6
+ from fastapi import APIRouter, HTTPException, UploadFile, File, Form
7
+ from pydantic import BaseModel, Field
8
+ from typing import Optional, List
9
+ import logging
10
+
11
+ from app.services.translation_service import get_translation_service
12
+
13
+ logger = logging.getLogger(__name__)
14
+ router = APIRouter(prefix="/translation", tags=["translation"])
15
+
16
+
17
+ # Request/Response Models
18
class TranslateTextRequest(BaseModel):
    """Request model for text translation."""
    # Bounded to 5000 chars to keep one MarianMT inference affordable.
    text: str = Field(..., min_length=1, max_length=5000, description="Text to translate")
    source_lang: str = Field(..., description="Source language code (e.g., 'hi', 'en-US')")
    target_lang: str = Field(..., description="Target language code (e.g., 'en', 'es')")
    # When True, language pairs without a direct model are routed source -> en -> target.
    use_pivot: bool = Field(default=True, description="Use English as pivot for unsupported pairs")
24
+
25
+
26
class TranslateTextResponse(BaseModel):
    """Response model for text translation."""
    translated_text: str
    source_lang: str
    target_lang: str
    source_text: str          # echo of the input text
    processing_time: float    # presumably seconds — TODO confirm against translation_service
    word_count: int
    # Pivot metadata: only meaningful when English pivoting was actually used.
    pivot_used: Optional[bool] = False
    intermediate_text: Optional[str] = None   # English intermediate text when pivoting
    model_used: Optional[str] = None          # identifier of the MarianMT model applied
37
+
38
+
39
class LanguageInfo(BaseModel):
    """Language information model (UI-facing metadata; same shape as
    LANGUAGE_METADATA entries in core.config)."""
    code: str     # language/locale code, e.g. "en-US"
    name: str     # English display name
    flag: str     # flag emoji
    native: str   # the language's name for itself, e.g. "Español"
45
+
46
+
47
class TranslationPair(BaseModel):
    """Translation pair model: one supported source -> target direction."""
    # Pair identifier — presumably "<source>-<target>"; confirm against translation_service
    code: str
    source: LanguageInfo
    target: LanguageInfo
52
+
53
+
54
class DetectLanguageResponse(BaseModel):
    """Response model for language detection."""
    detected_language: str    # language code reported by the detector
    confidence: float         # detector confidence — assumed 0.0-1.0; TODO confirm
    language_info: Optional[dict] = None             # display metadata when known
    all_probabilities: Optional[List[dict]] = None   # per-language candidate scores
60
+
61
+
62
+ # Endpoints
63
@router.get("/languages", response_model=List[LanguageInfo])
async def get_supported_languages():
    """Return every language the translation service supports.

    Returns:
        List of supported languages with display metadata.
    """
    return get_translation_service().get_supported_languages()
73
+
74
+
75
@router.get("/pairs")
async def get_supported_pairs():
    """
    Get list of all supported translation pairs.

    Returns:
        Dict with the list of supported source->target pairs and their count.
    """
    service = get_translation_service()
    # Fetch the pair list once instead of calling the service twice
    # (once for the payload and again just to compute the length).
    pairs = service.get_supported_pairs()
    return {
        "pairs": pairs,
        "total": len(pairs),
    }
88
+
89
+
90
@router.post("/text", response_model=TranslateTextResponse)
async def translate_text(request: TranslateTextRequest):
    """
    Translate text from source to target language.

    - Uses Helsinki-NLP MarianMT models (~300MB per language pair)
    - Supports pivot translation through English for unsupported pairs
    - First request for a language pair may take longer (model loading)

    Args:
        request: Translation request with text and language codes

    Returns:
        Translated text with metadata
    """
    service = get_translation_service()

    # Both service entry points share the same keyword signature, so pick
    # the pivot-capable or direct path up front.
    translate = service.translate_with_pivot if request.use_pivot else service.translate_text

    try:
        result = translate(
            text=request.text,
            source_lang=request.source_lang,
            target_lang=request.target_lang,
        )
        return TranslateTextResponse(**result)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Translation error: {e}")
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
128
+
129
+
130
@router.post("/detect", response_model=DetectLanguageResponse)
async def detect_language(text: str = Form(..., min_length=10, description="Text to analyze")):
    """
    Detect the language of input text.

    Args:
        text: Text to analyze (minimum 10 characters for accuracy)

    Returns:
        Detected language with confidence score
    """
    detection = get_translation_service().detect_language(text)

    # The service reports failures in-band via an "error" key rather than raising.
    error = detection.get("error")
    if error:
        raise HTTPException(status_code=400, detail=error)

    return DetectLanguageResponse(**detection)
148
+
149
+
150
@router.get("/model-info")
async def get_model_info():
    """
    Get information about loaded translation models.

    Returns:
        Model loading status and supported pairs
    """
    return get_translation_service().get_model_info()
160
+
161
+
162
@router.post("/audio")
async def translate_audio(
    file: UploadFile = File(..., description="Audio file to translate"),
    source_lang: str = Form(..., description="Source language code"),
    target_lang: str = Form(..., description="Target language code"),
    generate_audio: bool = Form(default=True, description="Generate TTS output"),
):
    """
    Full audio translation pipeline: STT → Translate → TTS

    1. Transcribe audio using Whisper
    2. Translate text using MarianMT
    3. Optionally generate speech in target language

    Args:
        file: Audio file (WAV, MP3, etc.)
        source_lang: Source language code
        target_lang: Target language code
        generate_audio: Whether to generate TTS output

    Returns:
        Transcription, translation, and optional base64 MP3 audio response

    Raises:
        HTTPException: 400 when no speech is detected, 500 on pipeline failure.
    """
    import base64
    import os
    import tempfile
    from app.services.whisper_stt_service import get_whisper_stt_service
    from app.services.edge_tts_service import get_edge_tts_service

    translation_service = get_translation_service()
    stt_service = get_whisper_stt_service()
    tts_service = get_edge_tts_service()

    # Persist the upload to a temp file so Whisper can read it from disk.
    suffix = os.path.splitext(file.filename)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        content = await file.read()
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # Step 1: Transcribe
        transcription = stt_service.transcribe_file(tmp_path, language=source_lang)
        source_text = transcription["text"]

        if not source_text.strip():
            raise HTTPException(status_code=400, detail="No speech detected in audio")

        # Step 2: Translate (pivots through English when no direct model exists)
        translation = translation_service.translate_with_pivot(
            text=source_text,
            source_lang=source_lang,
            target_lang=target_lang,
        )
        translated_text = translation["translated_text"]

        # Step 3: Generate TTS (optional)
        audio_base64 = None
        if generate_audio:
            # Map base language code to a default Edge TTS neural voice.
            voice_map = {
                "en": "en-US-AriaNeural",
                "hi": "hi-IN-SwaraNeural",
                "es": "es-ES-ElviraNeural",
                "fr": "fr-FR-DeniseNeural",
                "de": "de-DE-KatjaNeural",
                "zh": "zh-CN-XiaoxiaoNeural",
                "ja": "ja-JP-NanamiNeural",
                "ko": "ko-KR-SunHiNeural",
                "ar": "ar-SA-ZariyahNeural",
                "ru": "ru-RU-SvetlanaNeural",
            }
            # "en-US" -> "en": only the base language selects the voice.
            target_code = target_lang.split("-")[0].lower()
            voice = voice_map.get(target_code, "en-US-AriaNeural")

            audio_bytes = tts_service.synthesize_sync(translated_text, voice=voice)
            audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")

        return {
            "source_text": source_text,
            "translated_text": translated_text,
            "source_lang": source_lang,
            "target_lang": target_lang,
            "transcription_time": transcription["processing_time"],
            "translation_time": translation["processing_time"],
            "audio_base64": audio_base64,
            "audio_format": "mp3" if audio_base64 else None,
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Audio translation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Best-effort temp-file cleanup.
        # BUG FIX: was a bare `except:` which also swallows KeyboardInterrupt
        # and SystemExit; only filesystem errors should be ignored here.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
backend/app/api/routes/tts.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text-to-Speech API Router
3
+ """
4
+
5
+ import base64
6
+ import logging
7
+ from typing import Optional
8
+ from fastapi import APIRouter, HTTPException, Depends, Response, Request
9
+ from fastapi.responses import StreamingResponse
10
+ from io import BytesIO
11
+
12
+ from ...core.limiter import limiter
13
+
14
+ from ...services.tts_service import get_tts_service, TTSService
15
+ from ...schemas.tts import (
16
+ SynthesisRequest,
17
+ SynthesisResponse,
18
+ VoiceInfo,
19
+ VoiceListResponse,
20
+ VoicePreviewRequest,
21
+ )
22
+ from ...core.config import get_settings
23
+
24
+ logger = logging.getLogger(__name__)
25
+ router = APIRouter(prefix="/tts", tags=["Text-to-Speech"])
26
+ settings = get_settings()
27
+
28
+
29
@router.get("/voices", response_model=VoiceListResponse)
async def get_voices(
    language: Optional[str] = None,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Get list of available TTS voices

    Optionally filter by language code (e.g., "en-US", "es", "fr")
    """
    voice_list = await tts_service.get_voices(language_code=language)
    return voice_list
40
+
41
+
42
@router.get("/voices/{language}", response_model=VoiceListResponse)
async def get_voices_by_language(
    language: str,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Get voices for a specific language

    Accepts an exact locale code ("en-US") or a bare prefix ("en") that
    matches at least one supported locale.
    """
    supported = settings.supported_languages_list
    is_exact = language in supported
    # A bare prefix like "en" is accepted when any locale (en-US, en-GB, ...)
    # starts with it.
    has_prefix_match = any(code.startswith(language) for code in supported)
    if not is_exact and not has_prefix_match:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {language}"
        )

    return await tts_service.get_voices(language_code=language)
60
+
61
+
62
@router.post("/synthesize", response_model=SynthesisResponse)
@limiter.limit("10/minute")
async def synthesize_speech(
    request: Request,
    request_body: SynthesisRequest,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Synthesize text to speech

    Returns base64-encoded audio content along with metadata.
    Decode the audio_content field to get the audio bytes.
    """
    # Reject oversized payloads before touching the TTS engine.
    if len(request_body.text) > 5000:
        raise HTTPException(
            status_code=400,
            detail="Text too long. Maximum 5000 characters."
        )

    # Validate only the base language ("en" from "en-US") against the
    # base languages of the configured locale list.
    lang_base = request_body.language.split("-")[0]
    supported_bases = {code.split("-")[0] for code in settings.supported_languages_list}
    if lang_base not in supported_bases:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported language: {request_body.language}"
        )

    try:
        return await tts_service.synthesize(request_body)
    except ValueError as e:
        logger.error(f"Synthesis validation error: {e}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.exception(f"Synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=f"Synthesis failed: {str(e)}")
100
+
101
+
102
@router.post("/stream")
async def stream_speech(
    request: SynthesisRequest,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Stream text-to-speech audio

    Returns a chunked audio stream (audio/mpeg) for immediate playback.
    Best for long text to reduce latency (TTFB).
    """
    try:
        audio_iterator = tts_service.synthesize_stream(request)
        return StreamingResponse(audio_iterator, media_type="audio/mpeg")
    except Exception as e:
        logger.exception(f"Streaming synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
121
+
122
+
123
@router.post("/ssml")
async def synthesize_ssml(
    text: str,
    voice: str = "en-US-AriaNeural",
    rate: str = "medium",
    pitch: str = "medium",
    emphasis: Optional[str] = None,
    auto_breaks: bool = True,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Synthesize speech with SSML prosody control

    Supports advanced speech customization:
    - rate: 'x-slow', 'slow', 'medium', 'fast', 'x-fast'
    - pitch: 'x-low', 'low', 'medium', 'high', 'x-high'
    - emphasis: 'reduced', 'moderate', 'strong'
    - auto_breaks: Add natural pauses at punctuation

    Returns audio/mpeg stream.
    """
    try:
        from ...services.edge_tts_service import get_edge_tts_service

        edge_service = get_edge_tts_service()

        # Wrap the plain text in an SSML document carrying the prosody options.
        markup = edge_service.build_ssml(
            text=text,
            voice=voice,
            rate=rate,
            pitch=pitch,
            emphasis=emphasis,
            breaks=auto_breaks
        )

        # Hand the finished SSML to Edge TTS for synthesis.
        mp3_bytes = await edge_service.synthesize_ssml(markup, voice)

        return Response(
            content=mp3_bytes,
            media_type="audio/mpeg",
            headers={"Content-Disposition": "inline; filename=speech.mp3"}
        )
    except Exception as e:
        logger.exception(f"SSML synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
169
+
170
+
171
@router.post("/synthesize/audio")
async def synthesize_audio_file(
    request: SynthesisRequest,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Synthesize text and return audio file directly

    Returns the audio file as a downloadable stream.
    """
    # Maps the service's encoding labels onto response MIME types.
    encoding_to_mime = {
        "MP3": "audio/mpeg",
        "LINEAR16": "audio/wav",
        "OGG_OPUS": "audio/ogg",
    }

    try:
        result = await tts_service.synthesize(request)

        # The service returns base64 text; decode back to raw audio bytes.
        raw_audio = base64.b64decode(result.audio_content)
        mime_type = encoding_to_mime.get(result.encoding, "audio/mpeg")

        return StreamingResponse(
            BytesIO(raw_audio),
            media_type=mime_type,
            headers={
                "Content-Disposition": f'attachment; filename="speech.{result.encoding.lower()}"',
                "Content-Length": str(result.audio_size),
            }
        )
    except Exception as e:
        logger.exception(f"Audio synthesis failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
207
+
208
+
209
@router.post("/preview")
async def preview_voice(
    request: VoicePreviewRequest,
    tts_service: TTSService = Depends(get_tts_service),
):
    """
    Generate a short preview of a voice

    Returns a small audio sample (audio/mpeg) for voice selection UI.

    Raises:
        HTTPException: 404 when the voice name is unknown, 500 on synthesis failure.
    """
    # Find the voice to get its language.
    # BUG FIX: get_voices() is a coroutine (it is awaited everywhere else in
    # this router); without `await` this raised AttributeError on `.voices`.
    voice_list = await tts_service.get_voices()
    voice_info = next((v for v in voice_list.voices if v.name == request.voice), None)

    if not voice_info:
        raise HTTPException(status_code=404, detail=f"Voice not found: {request.voice}")

    # Create synthesis request with preview text
    synth_request = SynthesisRequest(
        text=request.text or "Hello! This is a preview of my voice.",
        language=voice_info.language_code,
        voice=request.voice,
        audio_encoding="MP3",
    )

    try:
        # BUG FIX: synthesize() is async as well and must be awaited.
        result = await tts_service.synthesize(synth_request)

        # Return audio directly
        audio_bytes = base64.b64decode(result.audio_content)
        return StreamingResponse(
            BytesIO(audio_bytes),
            media_type="audio/mpeg",
        )
    except Exception as e:
        logger.exception(f"Preview failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
backend/app/api/routes/ws.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebSocket Router for Real-Time Transcription
3
+ """
4
+
5
+ import logging
6
+ import json
7
+ from typing import Dict, Optional
8
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
9
+
10
+ from app.core.ws_security import (
11
+ validate_ws_origin,
12
+ authenticate_websocket,
13
+ ws_rate_limiter
14
+ )
15
+
16
+ logger = logging.getLogger(__name__)
17
+ router = APIRouter(prefix="/ws", tags=["WebSocket"])
18
+
19
+
20
class ConnectionManager:
    """Tracks live WebSocket connections and the user each one belongs to."""

    def __init__(self):
        # client_id -> accepted socket
        self.active_connections: Dict[str, WebSocket] = {}
        self.user_ids: Dict[str, Optional[int]] = {}  # Track authenticated users

    async def connect(
        self,
        client_id: str,
        websocket: WebSocket,
        user_id: Optional[int] = None
    ) -> bool:
        """
        Connect a client after validation.
        Returns True if connection accepted, False if rejected.
        """
        # Reject connections from disallowed origins before accepting.
        if not validate_ws_origin(websocket):
            logger.warning(f"WebSocket rejected for {client_id}: invalid origin")
            await websocket.close(code=1008)  # Policy Violation
            return False

        await websocket.accept()
        self.active_connections[client_id] = websocket
        self.user_ids[client_id] = user_id
        logger.info(f"Client {client_id} connected (user_id={user_id})")
        return True

    def disconnect(self, client_id: str):
        """Forget a client and release its rate-limiter state."""
        self.active_connections.pop(client_id, None)
        self.user_ids.pop(client_id, None)
        ws_rate_limiter.cleanup(client_id)
        logger.info(f"Client {client_id} disconnected")

    async def send_json(self, client_id: str, data: dict):
        """Send JSON to a client if it is still connected; no-op otherwise."""
        connection = self.active_connections.get(client_id)
        if connection is not None:
            await connection.send_json(data)
59
+
60
+
61
# Module-level singleton: one connection registry shared by both WebSocket
# endpoints in this router (per-process only; not shared across workers).
manager = ConnectionManager()
62
+
63
+
64
@router.websocket("/transcription/{client_id}")
async def websocket_transcription(
    websocket: WebSocket,
    client_id: str,
    token: Optional[str] = Query(None)
):
    """
    Real-time streaming transcription via WebSocket with VAD.

    Optional auth via query param: ws://host/ws/transcription/{id}?token=jwt_token
    """
    # Authenticate (optional for demo, but logged)
    user_id = await authenticate_websocket(websocket, token)

    if not await manager.connect(client_id, websocket, user_id):
        return  # Connection rejected

    from app.services.ws_stt_service import StreamManager, transcribe_buffer

    stream_manager = StreamManager(websocket)

    async def handle_transcription(audio_bytes: bytes):
        """Callback for processing speech segments."""
        # Check rate limit
        if not ws_rate_limiter.check_rate(client_id):
            await manager.send_json(client_id, {"error": "Rate limit exceeded"})
            return

        # Check message size
        if not ws_rate_limiter.check_size(audio_bytes):
            await manager.send_json(client_id, {"error": "Message too large"})
            return

        try:
            # Send processing status
            await manager.send_json(client_id, {"status": "processing"})

            # Transcribe
            result = await transcribe_buffer(audio_bytes)
            text = result.get("text", "").strip()

            if text:
                # Send result
                await manager.send_json(client_id, {
                    "text": text,
                    "is_final": True,
                    "status": "complete"
                })
                logger.info(f"Transcribed: {text}")
        except Exception as e:
            logger.error(f"Transcription callback error: {e}")
            await manager.send_json(client_id, {"error": str(e)})

    try:
        # Start processing loop
        await stream_manager.process_stream(handle_transcription)

    except WebSocketDisconnect:
        manager.disconnect(client_id)
    except Exception as e:
        logger.error(f"WebSocket error: {e}")
        try:
            await manager.send_json(client_id, {"error": str(e)})
        # BUG FIX: was a bare `except:`, which also swallows
        # asyncio.CancelledError and KeyboardInterrupt during shutdown; only
        # send failures on an already-broken socket should be ignored.
        except Exception:
            pass
        manager.disconnect(client_id)
130
+
131
+
132
@router.websocket("/tts/{client_id}")
async def websocket_tts(
    websocket: WebSocket,
    client_id: str,
    token: Optional[str] = Query(None)
):
    """
    Real-time Text-to-Speech via WebSocket

    Protocol:
    - Client sends: JSON {"text": "...", "voice": "...", "rate": "...", "pitch": "..."}
    - Server sends: Binary audio chunks (MP3) followed by JSON {"status": "complete"}

    Optional auth via query param: ws://host/ws/tts/{id}?token=jwt_token
    This achieves <500ms TTFB by streaming as chunks are generated.
    """
    # Authenticate (optional for demo, but logged)
    user_id = await authenticate_websocket(websocket, token)

    if not await manager.connect(client_id, websocket, user_id):
        return  # Connection rejected

    try:
        # Local imports keep heavy/optional deps out of module import time;
        # `time` hoisted out of the per-message loop (was re-imported each message).
        import edge_tts
        import time

        while True:
            # Receive synthesis request
            data = await websocket.receive_json()

            text = data.get("text", "")
            voice = data.get("voice", "en-US-AriaNeural")
            rate = data.get("rate", "+0%")
            pitch = data.get("pitch", "+0Hz")

            if not text:
                await websocket.send_json({"error": "No text provided"})
                continue

            logger.info(f"WebSocket TTS: Synthesizing '{text[:50]}...' with {voice}")

            # Stream audio chunks to the client as soon as edge-tts yields them.
            start_time = time.time()
            first_chunk_sent = False
            ttfb = None  # time-to-first-byte (ms), measured at the first audio chunk
            total_bytes = 0

            communicate = edge_tts.Communicate(text, voice, rate=rate, pitch=pitch)

            async for chunk in communicate.stream():
                if chunk["type"] == "audio":
                    await websocket.send_bytes(chunk["data"])
                    total_bytes += len(chunk["data"])

                    if not first_chunk_sent:
                        ttfb = (time.time() - start_time) * 1000
                        logger.info(f"WebSocket TTS TTFB: {ttfb:.0f}ms")
                        first_chunk_sent = True

            # Send completion marker
            total_time = time.time() - start_time
            await websocket.send_json({
                "status": "complete",
                "total_bytes": total_bytes,
                "total_time_ms": round(total_time * 1000),
                "ttfb_ms": round(ttfb) if first_chunk_sent else None
            })

    except WebSocketDisconnect:
        manager.disconnect(client_id)
    except Exception as e:
        logger.error(f"WebSocket TTS error: {e}")
        try:
            await websocket.send_json({"error": str(e)})
        # BUG FIX: was a bare `except:`, which also swallows
        # asyncio.CancelledError and KeyboardInterrupt; only send failures on
        # a dead socket should be ignored here.
        except Exception:
            pass
        manager.disconnect(client_id)
208
+
backend/app/core/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceForge Core Package
3
+ """
4
+
5
+ from .config import get_settings, Settings, LANGUAGE_METADATA
6
+
7
+ __all__ = ["get_settings", "Settings", "LANGUAGE_METADATA"]
backend/app/core/config.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceForge Configuration
3
+ Pydantic Settings for application configuration
4
+ """
5
+
6
+ from functools import lru_cache
7
+ from typing import List
8
+ from pydantic_settings import BaseSettings, SettingsConfigDict
9
+ from pydantic import Field
10
+
11
+
12
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Defaults are development-friendly (SQLite, local Redis); production
    deployments override them via environment variables or the .env file.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
        extra="allow",  # Allow extra env vars without error
    )

    # Application
    app_name: str = "VoiceForge"
    app_version: str = "1.0.0"
    debug: bool = False

    # API Server
    api_host: str = "0.0.0.0"
    api_port: int = 8000

    # Database
    database_url: str = Field(
        default="sqlite:///./voiceforge.db",
        description="Database connection URL (SQLite for dev, PostgreSQL for prod)"
    )

    # Redis
    redis_url: str = Field(
        default="redis://localhost:6379/0",
        description="Redis connection URL for caching and Celery"
    )

    # Google Cloud
    google_application_credentials: str = Field(
        default="./credentials/google-cloud-key.json",
        description="Path to Google Cloud service account JSON key"
    )

    # AI Services Configuration
    use_local_services: bool = Field(
        default=True,
        description="Use local free services (Whisper + EdgeTTS) instead of Google Cloud"
    )
    whisper_model: str = Field(
        default="small",
        description="Whisper model size (tiny, base, small, medium, large-v3)"
    )

    # Security
    # NOTE(review): insecure placeholder default — must be overridden via env in production.
    secret_key: str = Field(
        default="your-super-secret-key-change-in-production",
        description="Secret key for JWT encoding"
    )
    access_token_expire_minutes: int = 30
    algorithm: str = "HS256"  # JWT signing algorithm
    hf_token: str | None = Field(default=None, description="Hugging Face Token for Diarization")

    # File Storage
    upload_dir: str = "./uploads"
    max_audio_duration_seconds: int = 600  # 10 minutes
    max_upload_size_mb: int = 50

    # Supported Languages (comma-separated; parsed by supported_languages_list)
    supported_languages: str = "en-US,en-GB,es-ES,es-MX,fr-FR,de-DE,ja-JP,ko-KR,zh-CN,hi-IN"

    # Audio Formats (comma-separated; parsed by supported_audio_formats_list)
    supported_audio_formats: str = "wav,mp3,m4a,flac,ogg,webm"

    @property
    def supported_languages_list(self) -> List[str]:
        """Get supported languages as a list"""
        return [lang.strip() for lang in self.supported_languages.split(",")]

    @property
    def supported_audio_formats_list(self) -> List[str]:
        """Get supported audio formats as a list"""
        return [fmt.strip() for fmt in self.supported_audio_formats.split(",")]
88
+
89
+
90
# Language metadata for UI display.
# Keys match the locale codes in Settings.supported_languages; each value
# carries the English display name, a flag emoji, and the language's own name.
LANGUAGE_METADATA = {
    "en-US": {"name": "English (US)", "flag": "🇺🇸", "native": "English"},
    "en-GB": {"name": "English (UK)", "flag": "🇬🇧", "native": "English"},
    "es-ES": {"name": "Spanish (Spain)", "flag": "🇪🇸", "native": "Español"},
    "es-MX": {"name": "Spanish (Mexico)", "flag": "🇲🇽", "native": "Español"},
    "fr-FR": {"name": "French", "flag": "🇫🇷", "native": "Français"},
    "de-DE": {"name": "German", "flag": "🇩🇪", "native": "Deutsch"},
    "ja-JP": {"name": "Japanese", "flag": "🇯🇵", "native": "日本語"},
    "ko-KR": {"name": "Korean", "flag": "🇰🇷", "native": "한국어"},
    "zh-CN": {"name": "Chinese (Mandarin)", "flag": "🇨🇳", "native": "中文"},
    "hi-IN": {"name": "Hindi", "flag": "🇮🇳", "native": "हिन्दी"},
}
103
+
104
+
105
@lru_cache
def get_settings() -> Settings:
    """Get cached settings instance.

    lru_cache ensures Settings() (and its .env parsing) runs only once per
    process; all callers share the same object.
    """
    return Settings()
backend/app/core/limiter.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from slowapi import Limiter
from slowapi.util import get_remote_address
# NOTE(review): RateLimitExceeded is unused in this module — presumably
# imported so other modules can import it from here; confirm before removing.
from slowapi.errors import RateLimitExceeded

# Initialize Limiter
# Use in-memory storage for local dev (Redis for production)
redis_url = os.getenv("REDIS_URL")

# For local testing without Redis, use memory storage.
# The probe below connects and pings once at import time; on any failure
# (missing package, unreachable server) we silently fall back to memory.
if redis_url and redis_url.strip():
    try:
        import redis
        r = redis.from_url(redis_url)
        r.ping()  # Test connection
        storage_uri = redis_url
    except Exception:
        # Redis not available, fall back to memory
        storage_uri = "memory://"
else:
    storage_uri = "memory://"

# Shared limiter instance; routers apply per-endpoint limits via @limiter.limit(...).
limiter = Limiter(
    key_func=get_remote_address,
    storage_uri=storage_uri,
    default_limits=["60/minute"]  # Global limit: 60 req/min per IP
)
backend/app/core/middleware.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Rate Limiting Middleware
3
+ Uses Redis to track and limit request rates per IP address.
4
+ Pure ASGI implementation to avoid BaseHTTPMiddleware issues.
5
+ """
6
+
7
+ import time
8
+ import redis
9
+ from starlette.responses import JSONResponse
10
+ from starlette.types import ASGIApp, Scope, Receive, Send
11
+ from ..core.config import get_settings
12
+
13
+ settings = get_settings()
14
+
15
class RateLimitMiddleware:
    """Pure-ASGI fixed-window rate limiter backed by Redis.

    Limits /api/* HTTP requests to 60 per minute per client IP and fails
    open (no limiting) whenever Redis is unavailable.
    """

    def __init__(self, app: ASGIApp):
        self.app = app
        # Hardcoded or from settings (bypassing constructor arg issue)
        self.requests_per_minute = 60
        self.window_size = 60  # seconds

        # Connect to Redis
        try:
            self.redis_client = redis.from_url(settings.redis_url)
        except Exception as e:
            print(f"⚠️ Rate limiter disabled: Could not connect to Redis ({e})")
            self.redis_client = None

    def _applies_to(self, scope: Scope) -> bool:
        """True only for HTTP requests under /api/ while Redis is available."""
        if scope["type"] != "http":
            return False
        if self.redis_client is None:
            return False
        return scope.get("path", "").startswith("/api/")

    def _limit_response(self, scope: Scope):
        """Return a 429 JSONResponse when the client is over its window, else None."""
        client = scope.get("client")
        client_ip = client[0] if client else "unknown"
        key = f"rate_limit:{client_ip}"

        try:
            # Simple fixed window counter
            current_count = self.redis_client.incr(key)

            # Set expiry on first request
            if current_count == 1:
                self.redis_client.expire(key, self.window_size)

            if current_count > self.requests_per_minute:
                return JSONResponse(
                    status_code=429,
                    content={
                        "detail": "Too many requests",
                        "retry_after": self.window_size
                    },
                    headers={"Retry-After": str(self.window_size)}
                )
        except redis.RedisError:
            # Fail open if Redis has issues during request
            pass
        return None

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if not self._applies_to(scope):
            await self.app(scope, receive, send)
            return

        rejection = self._limit_response(scope)
        if rejection is not None:
            await rejection(scope, receive, send)
            return

        await self.app(scope, receive, send)
backend/app/core/request_size_middleware.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Request Size Limiting Middleware
3
+ Prevents large request bodies from consuming excessive memory.
4
+ """
5
+
6
+ import os
7
+ import logging
8
+ from starlette.middleware.base import BaseHTTPMiddleware
9
+ from starlette.requests import Request
10
+ from starlette.responses import Response
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Default max body size: 50MB (configurable via env)
15
+ DEFAULT_MAX_BODY_SIZE = 50 * 1024 * 1024 # 50 MB
16
+
17
+
18
class RequestSizeLimitMiddleware(BaseHTTPMiddleware):
    """
    ASGI Middleware to limit request body size.

    Checks the Content-Length header and rejects requests that exceed the
    limit with a 413 Payload Too Large response, before the body is read.

    Note: requests using chunked transfer encoding carry no Content-Length,
    so nothing is enforced here for them — pair this with
    StreamingSizeValidator when reading uploads.
    """

    def __init__(self, app, max_body_size: "int | None" = None):
        super().__init__(app)
        # Explicit argument wins; otherwise MAX_REQUEST_BODY_SIZE env var,
        # then the 50 MB module default.
        self.max_body_size = max_body_size or int(
            os.getenv("MAX_REQUEST_BODY_SIZE", DEFAULT_MAX_BODY_SIZE)
        )
        logger.info(f"Request size limit: {self.max_body_size / 1024 / 1024:.1f} MB")

    async def dispatch(self, request: Request, call_next):
        """Reject requests whose declared Content-Length exceeds the limit."""
        # Skip for WebSocket upgrades
        if request.headers.get("upgrade", "").lower() == "websocket":
            return await call_next(request)

        # Check Content-Length header (absent for chunked bodies)
        content_length = request.headers.get("content-length")

        if content_length:
            try:
                size = int(content_length)
                if size > self.max_body_size:
                    logger.warning(
                        f"Request too large: {size / 1024 / 1024:.1f} MB "
                        f"(limit: {self.max_body_size / 1024 / 1024:.1f} MB) "
                        f"from {request.client.host if request.client else 'unknown'}"
                    )
                    return Response(
                        content="Request body too large",
                        status_code=413,
                        media_type="text/plain"
                    )
            except ValueError:
                pass  # Invalid Content-Length, let the server handle it

        return await call_next(request)
62
+
63
+
64
class StreamingSizeValidator:
    """
    Running byte counter for streamed uploads with a hard cap.

    Feed it the length of each chunk as you read; it raises ValueError the
    moment the running total passes ``max_size``. Intended companion to
    SpooledTemporaryFile for memory-efficient large uploads:

        validator = StreamingSizeValidator(max_size=100 * 1024 * 1024)
        with SpooledTemporaryFile(max_size=5*1024*1024) as tmp:
            async for chunk in file.stream():
                validator.add(len(chunk))  # raises if exceeded
                tmp.write(chunk)
    """

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.current_size = 0

    def add(self, chunk_size: int):
        """Account for one chunk; raise ValueError once the cap is passed."""
        running_total = self.current_size + chunk_size
        self.current_size = running_total
        if running_total > self.max_size:
            raise ValueError(
                f"Upload exceeds size limit: {self.current_size} > {self.max_size}"
            )

    @property
    def size(self) -> int:
        """Total number of bytes accounted for so far."""
        return self.current_size
backend/app/core/security.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Security Utilities
3
+ Handles password hashing, JWT generation, and API key verification.
4
+ """
5
+
6
+ from datetime import datetime, timedelta
7
+ from typing import Optional, Union, Any
8
+ from jose import jwt
9
+ from passlib.context import CryptContext
10
+ from fastapi.security import OAuth2PasswordBearer, APIKeyHeader
11
+ from fastapi import Depends, HTTPException, status
12
+ from sqlalchemy.orm import Session
13
+
14
+ from ..core.config import get_settings
15
+ from ..models import get_db, User, ApiKey
16
+
17
+ settings = get_settings()
18
+
19
+ # Password hashing (PBKDF2 is safer/easier on Windows than bcrypt sometimes)
20
+ pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
21
+
22
+ # JWT configuration
23
+ SECRET_KEY = settings.secret_key
24
+ ALGORITHM = settings.algorithm
25
+ ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
26
+
27
+ # OAuth2 scheme - auto_error=False allows API key fallback
28
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/v1/auth/login", auto_error=False)
29
+ api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
30
+
31
+
32
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against its stored PBKDF2-SHA256 hash."""
    return pwd_context.verify(plain_password, hashed_password)
34
+
35
def get_password_hash(password: str) -> str:
    """Hash a plaintext password with the module's PBKDF2-SHA256 context."""
    return pwd_context.hash(password)
37
+
38
def create_access_token(subject: Union[str, Any], expires_delta: Optional[timedelta] = None) -> str:
    """
    Create a signed JWT access token.

    Args:
        subject: Value stored in the "sub" claim (coerced to str), typically a user id.
        expires_delta: Optional custom lifetime; defaults to
            ACCESS_TOKEN_EXPIRE_MINUTES from settings.

    Returns:
        The encoded JWT string, signed with SECRET_KEY using ALGORITHM.
    """
    # Local import: the module only imports datetime/timedelta at the top.
    from datetime import timezone

    # Timezone-aware "now" — datetime.utcnow() is deprecated (Python 3.12+),
    # and jose converts aware datetimes to the correct UTC timestamp.
    now = datetime.now(timezone.utc)
    if expires_delta:
        expire = now + expires_delta
    else:
        expire = now + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = {"exp": expire, "sub": str(subject)}
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
47
+
48
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: Session = Depends(get_db)
) -> Optional[User]:
    """
    Validate a bearer JWT and return the matching user.

    Returns None (instead of raising) when the token is missing, expired,
    malformed, or carries an absent/non-numeric "sub" claim, so that callers
    can fall back to API-key authentication.
    """
    if not token:
        return None  # Allow API key fallback

    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        user_id = payload.get("sub")
        if user_id is None:
            return None
        # BUGFIX: int() used to run outside this try block, so a token whose
        # "sub" claim was not numeric raised ValueError and surfaced as a 500
        # instead of being treated as an invalid credential.
        user_pk = int(user_id)
    except Exception:
        # jose raises JWTError subclasses on bad signature/expiry; anything
        # else here also means "not authenticated via JWT".
        return None

    return db.query(User).filter(User.id == user_pk).first()
66
+
67
async def get_current_active_user(current_user: Optional[User] = Depends(get_current_user)) -> User:
    """Resolve the authenticated user, rejecting anonymous or disabled accounts."""
    unauthenticated = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated",
        headers={"WWW-Authenticate": "Bearer"},
    )
    if current_user is None:
        raise unauthenticated
    if not current_user.is_active:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
78
+
79
async def verify_api_key(
    api_key: str = Depends(api_key_header),
    db: Session = Depends(get_db)
) -> Optional[User]:
    """
    Validate an API key taken from the X-API-Key header.

    Returns the key's owning user when the key exists and is active, else
    None (callers decide whether missing credentials are an error).

    Side effect: bumps the key's last_used_at and commits on every
    successful lookup — one extra DB write per authenticated request.
    """
    if not api_key:
        return None  # Or raise if strict

    # `== True` is intentional here: it builds a SQLAlchemy boolean
    # expression on the column, not a Python identity check.
    key_record = db.query(ApiKey).filter(ApiKey.key == api_key, ApiKey.is_active == True).first()

    if key_record:
        # Update usage stats
        key_record.last_used_at = datetime.utcnow()
        db.commit()
        return key_record.user

    return None  # Invalid key
99
+
100
def get_api_user_or_jwt_user(
    api_key_user: Optional[User] = Depends(verify_api_key),
    jwt_user: Optional[User] = Depends(get_current_user)
) -> User:
    """Accept either credential type; a valid API key wins over a JWT."""
    user = api_key_user or jwt_user
    if user is not None:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Not authenticated"
    )
backend/app/core/security_encryption.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Field-level Encryption for SQLAlchemy Models.
3
+
4
+ Uses Fernet symmetric encryption from the `cryptography` library.
5
+ The ENCRYPTION_KEY should be a 32-byte base64-encoded key.
6
+ Generate one with: from cryptography.fernet import Fernet; print(Fernet.generate_key())
7
+ """
8
+
9
+ import os
10
+ import base64
11
+ import logging
12
+ from typing import Optional
13
+
14
+ from cryptography.fernet import Fernet, InvalidToken
15
+ from sqlalchemy import TypeDecorator, String
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # --- Configuration ---
20
+ # IMPORTANT: Store this securely! In production, use secrets manager or env vars.
21
+ # Default key is for development ONLY - regenerate for production!
22
+ _DEFAULT_DEV_KEY = "VOICEFORGE_DEV_KEY_REPLACE_ME_NOW=" # Placeholder - NOT a valid key
23
+
24
def _get_encryption_key() -> bytes:
    """
    Return the Fernet key from the environment, failing closed in production.

    Resolution order:
      1. ENCRYPTION_KEY env var (must already be a url-safe base64 Fernet key).
      2. If ENVIRONMENT=production and no key is set: raise RuntimeError.
      3. Otherwise: a fixed, INSECURE development-only key.

    Raises:
        RuntimeError: when running in production without ENCRYPTION_KEY.
    """
    key_str = os.getenv("ENCRYPTION_KEY")

    if key_str:
        return key_str.encode()

    # Check if running in production
    is_production = os.getenv("ENVIRONMENT", "development").lower() == "production"
    if is_production:
        raise RuntimeError(
            "ENCRYPTION_KEY environment variable must be set in production! "
            "Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
        )

    # Development fallback - deterministic but INSECURE
    logger.warning("⚠️ ENCRYPTION_KEY not set! Using DEV-ONLY fixed key. DO NOT USE IN PRODUCTION.")
    # BUGFIX: the previous literal was only 30 bytes long, so Fernet(key)
    # rejected it ("Fernet key must be 32 url-safe base64-encoded bytes")
    # and dev encryption never worked. Pad the material to exactly 32 bytes.
    dev_material = b"voiceforge_dev_key_32bytes_ok!".ljust(32, b"!")
    return base64.urlsafe_b64encode(dev_material)
43
+
44
+ # Cache the Fernet instance
45
+ _fernet: Optional[Fernet] = None
46
+
47
def get_fernet() -> Fernet:
    """
    Get or create the module-wide Fernet instance (lazy singleton).

    The key is resolved once via _get_encryption_key() and cached in the
    module global `_fernet`; a changed ENCRYPTION_KEY needs a process
    restart. There is no lock here, so concurrent first calls may each
    build an instance — harmless, since both use the same key.
    """
    global _fernet
    if _fernet is None:
        key = _get_encryption_key()
        _fernet = Fernet(key)
    return _fernet
54
+
55
+
56
+ # --- SQLAlchemy TypeDecorator ---
57
+
58
class EncryptedString(TypeDecorator):
    """
    SQLAlchemy type that encrypts/decrypts string values transparently.

    Usage:
        class User(Base):
            full_name = Column(EncryptedString(255), nullable=True)

    The encrypted data is stored as a base64-encoded Fernet token string.

    SECURITY NOTE(review): process_bind_param fails OPEN — if encryption
    errors out, the plaintext is stored silently. Acceptable for dev only;
    in production this should raise instead.
    """
    impl = String
    cache_ok = True

    def __init__(self, length: int = 512, *args, **kwargs):
        # Encrypted strings are longer than plaintext, so pad the length.
        # NOTE(review): doubling assumes plaintext is long enough; a Fernet
        # token has a fixed overhead (~100+ chars), so very small `length`
        # values could still truncate on strict databases — confirm.
        super().__init__(length * 2, *args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Encrypt the value before storing in DB (None passes through)."""
        if value is None:
            return None

        try:
            fernet = get_fernet()
            # Encode string to bytes, encrypt, then decode to string for storage
            encrypted = fernet.encrypt(value.encode('utf-8'))
            return encrypted.decode('utf-8')
        except Exception as e:
            logger.error(f"Encryption failed: {e}")
            # In case of encryption failure, store plaintext (fail-open for dev)
            # In production, you might want to raise instead
            return value

    def process_result_value(self, value, dialect):
        """Decrypt the value when reading from DB (None passes through)."""
        if value is None:
            return None

        try:
            fernet = get_fernet()
            # Decode from storage string, decrypt, then decode to string
            decrypted = fernet.decrypt(value.encode('utf-8'))
            return decrypted.decode('utf-8')
        except InvalidToken:
            # Value might be plaintext (legacy data or encryption disabled)
            logger.warning("Decryption failed - returning raw value (possible legacy data)")
            return value
        except Exception as e:
            logger.error(f"Decryption failed: {e}")
            return value
backend/app/core/security_headers.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from starlette.middleware.base import BaseHTTPMiddleware
2
+ from starlette.types import ASGIApp, Receive, Scope, Send
3
+
4
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """
    Attach defense-in-depth security headers to every HTTP response.

    Headers set: X-Frame-Options, X-Content-Type-Options, X-XSS-Protection,
    Strict-Transport-Security, Content-Security-Policy, Referrer-Policy.

    (The previous no-op __init__ override was removed — it only forwarded to
    super(), which BaseHTTPMiddleware already provides.)
    """

    async def dispatch(self, request, call_next):
        response = await call_next(request)

        # Prevent Clickjacking
        response.headers["X-Frame-Options"] = "DENY"

        # Prevent MIME type sniffing
        response.headers["X-Content-Type-Options"] = "nosniff"

        # Enable XSS filtering in browser (legacy but good for depth)
        response.headers["X-XSS-Protection"] = "1; mode=block"

        # Strict Transport Security (HSTS): enforce HTTPS for 1 year,
        # including subdomains. Only effective when served over HTTPS.
        response.headers["Strict-Transport-Security"] = "max-age=31536000; includeSubDomains"

        # Content Security Policy: default to same-origin; inline styles and
        # scripts remain allowed for Swagger UI compatibility.
        response.headers["Content-Security-Policy"] = "default-src 'self'; img-src 'self' data: https:; style-src 'self' 'unsafe-inline'; script-src 'self' 'unsafe-inline';"

        # Referrer Policy
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        return response
backend/app/core/ws_security.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebSocket Security Utilities
3
+ Authentication and validation for WebSocket connections.
4
+ """
5
+
6
+ import os
7
+ import logging
8
+ from typing import Optional, Set
9
+ from urllib.parse import urlparse
10
+
11
+ from fastapi import WebSocket, WebSocketException, status
12
+ from jose import jwt, JWTError
13
+
14
+ from .config import get_settings
15
+
16
+ logger = logging.getLogger(__name__)
17
+ settings = get_settings()
18
+
19
+ # Allowed origins for WebSocket connections
20
+ _allowed_ws_origins: Optional[Set[str]] = None
21
+
22
+
23
def get_allowed_origins() -> Set[str]:
    """
    Get the set of allowed WebSocket origins.

    Computed once from the CORS_ORIGINS env var (comma-separated, trailing
    slashes stripped) and cached in the module global `_allowed_ws_origins`;
    env changes after the first call require a restart. When SPACE_ID is set
    (HuggingFace Spaces), the HF origins are added automatically.
    """
    global _allowed_ws_origins
    if _allowed_ws_origins is None:
        origins_str = os.getenv(
            "CORS_ORIGINS",
            "http://localhost:8501,http://localhost:3000,http://localhost:8000"
        )
        _allowed_ws_origins = {o.strip().rstrip('/') for o in origins_str.split(",")}

        # Add HuggingFace origins if deploying there
        if os.getenv("SPACE_ID"):
            _allowed_ws_origins.add("https://huggingface.co")
            _allowed_ws_origins.add(f"https://{os.getenv('SPACE_ID')}.hf.space")

        logger.info(f"WebSocket allowed origins: {_allowed_ws_origins}")

    return _allowed_ws_origins
41
+
42
+
43
def validate_ws_origin(websocket: WebSocket) -> bool:
    """
    Validate the WebSocket Origin header against allowed origins.
    Returns True if valid, False otherwise.

    Security notes (review):
    - A missing Origin header is ACCEPTED (fail-open) to support non-browser
      and same-origin clients; tighten for production.
    - Any origin whose host ends in ".hf.space" is accepted regardless of
      scheme — presumably for HuggingFace Spaces; confirm this is intended.
    """
    origin = websocket.headers.get("origin")

    if not origin:
        # No origin header - could be same-origin or non-browser client
        # In production, you might want to reject these
        logger.warning("WebSocket connection without Origin header")
        return True  # Allow for dev/non-browser clients

    # Normalize origin (allowed set is stored without trailing slashes)
    origin = origin.rstrip('/')
    allowed = get_allowed_origins()

    if origin in allowed:
        return True

    # Check for wildcard subdomain match (e.g., *.hf.space)
    parsed = urlparse(origin)
    if parsed.netloc.endswith('.hf.space'):
        return True

    logger.warning(f"WebSocket rejected: origin '{origin}' not in allowed list")
    return False
70
+
71
+
72
async def authenticate_websocket(
    websocket: WebSocket,
    token: Optional[str] = None
) -> Optional[int]:
    """
    Authenticate a WebSocket connection using a JWT.

    Token resolution order:
    - the explicit `token` argument (e.g. extracted by the caller from the
      first message after connect);
    - otherwise the `token` query parameter (ws://host/ws/endpoint?token=xxx).

    Returns the user id from the "sub" claim when the token verifies, else
    None. Note: "sub" is assumed to be numeric — a non-numeric value makes
    int() raise into the generic except below, which also yields None.
    """
    # Try query param first
    if not token:
        token = websocket.query_params.get("token")

    if not token:
        return None

    try:
        payload = jwt.decode(
            token,
            settings.secret_key,
            algorithms=[settings.algorithm]
        )
        user_id = payload.get("sub")
        if user_id:
            return int(user_id)
    except JWTError as e:
        logger.warning(f"WebSocket auth failed: {e}")
    except Exception as e:
        logger.error(f"WebSocket auth error: {e}")

    return None
107
+
108
+
109
async def require_ws_auth(websocket: WebSocket) -> int:
    """
    Gatekeeper for authenticated WebSocket endpoints.

    Checks the Origin header first, then the JWT. On either failure the
    socket is closed with policy-violation (1008) and WebSocketException is
    raised. Returns the authenticated user's id.

    Usage:
        @router.websocket("/secure/{client_id}")
        async def secure_ws(websocket: WebSocket, client_id: str):
            user_id = await require_ws_auth(websocket)
            await websocket.accept()
            # ... handle connection
    """
    async def _reject():
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        raise WebSocketException(code=status.WS_1008_POLICY_VIOLATION)

    if not validate_ws_origin(websocket):
        await _reject()

    user_id = await authenticate_websocket(websocket)
    if not user_id:
        await _reject()

    return user_id
131
+
132
+
133
class WebSocketRateLimiter:
    """
    Per-connection throttle for WebSocket traffic.

    Applies a fixed one-second window per client id plus a hard cap on
    individual message size. State is an in-process dict, so the limits are
    per worker, not global.
    """

    def __init__(self, max_messages_per_second: int = 10, max_message_size: int = 1024 * 1024):
        self.max_rate = max_messages_per_second
        self.max_size = max_message_size
        self._counts: dict = {}  # client_id -> (count, window_start_time)

    def check_rate(self, client_id: str) -> bool:
        """Record one message for client_id; False once the per-second cap is hit."""
        import time
        now = time.time()

        entry = self._counts.get(client_id)
        if entry is None or now - entry[1] >= 1.0:
            # Unknown client, or the one-second window elapsed: start fresh.
            self._counts[client_id] = (1, now)
            return True

        count, window_start = entry
        if count >= self.max_rate:
            logger.warning(f"WebSocket rate limit exceeded for {client_id}")
            return False

        self._counts[client_id] = (count + 1, window_start)
        return True

    def check_size(self, data: bytes) -> bool:
        """True when the payload fits within the configured byte limit."""
        within_limit = len(data) <= self.max_size
        if not within_limit:
            logger.warning(f"WebSocket message too large: {len(data)} > {self.max_size}")
        return within_limit

    def cleanup(self, client_id: str):
        """Drop all throttle state for a disconnected client."""
        self._counts.pop(client_id, None)
178
+
179
+
180
+ # Global rate limiter instance
181
+ ws_rate_limiter = WebSocketRateLimiter(max_messages_per_second=20, max_message_size=5 * 1024 * 1024)
backend/app/main.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceForge - FastAPI Main Application
3
+ Production-grade Speech-to-Text & Text-to-Speech API
4
+ """
5
+
6
+ import logging
7
# WARN: PyTorch 2.6+ security workaround for Pyannote.
# TORCH_FORCE_WEIGHTS_ONLY_LOAD must be set before any other torch import so
# torch.load() keeps accepting full pickled checkpoints.
import os
os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
import torch.serialization
try:
    # Allowlist `dict` for weights-only loading (no-op on torch versions
    # without add_safe_globals).
    torch.serialization.add_safe_globals([dict])
except Exception:
    # BUGFIX: previously a bare `except:`, which would also swallow
    # KeyboardInterrupt/SystemExit during startup.
    pass
16
+
17
+ from contextlib import asynccontextmanager
18
+ from fastapi import FastAPI, Request
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.responses import JSONResponse
21
+ from fastapi.openapi.utils import get_openapi
22
+
23
+ from prometheus_fastapi_instrumentator import Instrumentator
24
+ from .core.config import get_settings
25
+ from .api.routes import (
26
+ stt_router,
27
+ tts_router,
28
+ health_router,
29
+ transcripts_router,
30
+ ws_router,
31
+ translation_router,
32
+ batch_router,
33
+ analysis_router,
34
+ audio_router,
35
+ cloning_router,
36
+ sign_router,
37
+ auth_router,
38
+ s2s_router,
39
+ sign_bridge # Import the module
40
+ )
41
+ from .models import Base, engine
42
+
43
+
44
+
45
+ # Configure logging
46
+ logging.basicConfig(
47
+ level=logging.INFO,
48
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
49
+ )
50
+ logger = logging.getLogger(__name__)
51
+
52
+ settings = get_settings()
53
+
54
+
55
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan handler.

    Everything before `yield` runs at startup; everything after runs at
    shutdown. Model/voice pre-warming is best-effort: failures are logged
    as warnings and the app still starts (models load lazily on first use).
    """
    # Startup
    logger.info(f"Starting {settings.app_name} v{settings.app_version}")

    # Create database tables (create_all only adds missing tables; it does
    # not migrate existing schemas)
    logger.info("Creating database tables...")
    Base.metadata.create_all(bind=engine)

    # Pre-warm Whisper models for faster first request
    logger.info("Pre-warming AI models...")
    try:
        from .services.whisper_stt_service import get_whisper_model
        # Pre-load English Distil model (most common)
        get_whisper_model("distil-small.en")
        logger.info("✅ Distil-Whisper model loaded")
        # Pre-load multilingual model
        get_whisper_model("small")
        logger.info("✅ Whisper-small model loaded")
    except Exception as e:
        logger.warning(f"Model pre-warming failed: {e}")

    # Pre-cache TTS voice list (network call; best-effort)
    try:
        from .services.tts_service import get_tts_service
        tts_service = get_tts_service()
        await tts_service.get_voices()
        logger.info("✅ TTS voice list cached")
    except Exception as e:
        logger.warning(f"Voice list caching failed: {e}")

    logger.info("🚀 Startup complete - All models warmed up!")

    yield

    # Shutdown
    logger.info("Shutting down...")
    # TODO: Close database connections
    # TODO: Close Redis connections
    logger.info("Shutdown complete")
99
+
100
+
101
+ # Create FastAPI application
102
+ app = FastAPI(
103
+ title=settings.app_name,
104
+ description="""
105
+ ## VoiceForge API
106
+
107
+ Production-grade Speech-to-Text and Text-to-Speech API.
108
+
109
+ ### Features
110
+
111
+ - 🎤 **Speech-to-Text**: Transcribe audio files with word-level timestamps
112
+ - 🔊 **Text-to-Speech**: Synthesize speech with 300+ neural voices
113
+ - 🌍 **Multi-language**: Support for 10+ languages
114
+ - 🧠 **AI Analysis**: Sentiment, keywords, and summarization
115
+ - 🌐 **Translation**: Translate text/audio between 20+ languages
116
+ - ⚡ **Free & Fast**: Local Whisper + Edge TTS - no API costs
117
+ """,
118
+ version=settings.app_version,
119
+ docs_url="/docs",
120
+ redoc_url="/redoc",
121
+ lifespan=lifespan,
122
+ )
123
+
124
+
125
+ from slowapi import _rate_limit_exceeded_handler
126
+ from slowapi.errors import RateLimitExceeded
127
+ from slowapi.middleware import SlowAPIMiddleware
128
+ from .core.limiter import limiter
129
+ from .core.security_headers import SecurityHeadersMiddleware
130
+ from .core.request_size_middleware import RequestSizeLimitMiddleware
131
+
132
+ # Request body size limit (must be first to reject large requests early)
133
+ app.add_middleware(RequestSizeLimitMiddleware)
134
+
135
+ # Add Rate Limiting (default: 60 requests/min per IP)
136
+ app.state.limiter = limiter
137
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
138
+ app.add_middleware(SlowAPIMiddleware)
139
+
140
+ # Security Headers (Must be before CORS to ensure headers are present even on errors/CORS blocks)
141
+ app.add_middleware(SecurityHeadersMiddleware)
142
+
143
+ # CORS middleware - Use CORS_ORIGINS env var (comma-separated) or defaults
144
+ _cors_origins = os.getenv(
145
+ "CORS_ORIGINS",
146
+ "http://localhost:8501,http://localhost:3000,http://localhost:8000"
147
+ ).split(",")
148
+ # Add HuggingFace origins if deploying there
149
+ if os.getenv("SPACE_ID"): # HF Spaces sets this
150
+ _cors_origins.extend(["https://huggingface.co", f"https://{os.getenv('SPACE_ID')}.hf.space"])
151
+
152
+ app.add_middleware(
153
+ CORSMiddleware,
154
+ allow_origins=[o.strip() for o in _cors_origins],
155
+ allow_credentials=True,
156
+ allow_methods=["*"],
157
+ allow_headers=["*"],
158
+ )
159
+
160
+ # Prometheus Metrics
161
+ Instrumentator().instrument(app).expose(app)
162
+
163
+
164
+ # Include routers
165
+ app.include_router(health_router)
166
+ app.include_router(auth_router, prefix="/api/v1")
167
+ app.include_router(stt_router, prefix="/api/v1")
168
+ app.include_router(tts_router, prefix="/api/v1")
169
+ app.include_router(transcripts_router, prefix="/api/v1")
170
+ app.include_router(ws_router, prefix="/api/v1")
171
+ app.include_router(translation_router, prefix="/api/v1")
172
+ app.include_router(batch_router, prefix="/api/v1")
173
+ app.include_router(analysis_router, prefix="/api/v1")
174
+ app.include_router(audio_router, prefix="/api/v1")
175
+ app.include_router(cloning_router, prefix="/api/v1")
176
+ app.include_router(sign_router, prefix="/api/v1")
177
+ app.include_router(s2s_router, prefix="/api/v1") # Added s2s_router
178
+ app.include_router(sign_bridge.router, prefix="/api/v1") # Added sign_bridge_router
179
+
180
+
181
+
182
+
183
+
184
+ # Exception handlers
185
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all for unhandled errors: log the traceback, hide details unless debug."""
    logger.exception(f"Unhandled error: {exc}")
    body = {
        "error": "internal_server_error",
        "message": "An unexpected error occurred",
        "detail": str(exc) if settings.debug else None,
    }
    return JSONResponse(status_code=500, content=body)
197
+
198
+
199
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
    """Map ValueError raised during request handling to a 400 response."""
    payload = {
        "error": "validation_error",
        "message": str(exc),
    }
    return JSONResponse(status_code=400, content=payload)
209
+
210
+
211
+ # Root endpoint
212
@app.get("/", tags=["Root"])
async def root():
    """Service banner: name, version, and where to find the docs and health check."""
    info = {
        "name": settings.app_name,
        "version": settings.app_version,
        "status": "running",
        "docs": "/docs",
        "health": "/health",
    }
    return info
222
+
223
+
224
+ # Custom OpenAPI schema
225
def custom_openapi():
    """
    Generate a custom OpenAPI schema with enhanced documentation.

    The result is memoized on app.openapi_schema, so routes registered after
    the first render are not reflected unless that attribute is cleared.
    """
    if app.openapi_schema:
        return app.openapi_schema

    openapi_schema = get_openapi(
        title=settings.app_name,
        version=settings.app_version,
        description=app.description,
        routes=app.routes,
    )

    # Add custom logo
    # NOTE(review): placeholder URL — replace with a real hosted logo.
    openapi_schema["info"]["x-logo"] = {
        "url": "https://example.com/logo.png"
    }

    # Add tags with descriptions (shown as section headers in /docs)
    openapi_schema["tags"] = [
        {
            "name": "Health",
            "description": "Health check endpoints for monitoring",
        },
        {
            "name": "Speech-to-Text",
            "description": "Convert audio to text with timestamps and speaker detection",
        },
        {
            "name": "Text-to-Speech",
            "description": "Convert text to natural-sounding speech",
        },
    ]

    app.openapi_schema = openapi_schema
    return app.openapi_schema
260
+
261
+
262
+ app.openapi = custom_openapi
263
+
264
+
265
+ if __name__ == "__main__":
266
+ import uvicorn
267
+
268
+ uvicorn.run(
269
+ "app.main:app",
270
+ host=settings.api_host,
271
+ port=settings.api_port,
272
+ reload=settings.debug,
273
+ )
backend/app/models/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceForge Database Models Package
3
+ """
4
+
5
+ from .base import Base, engine, SessionLocal, get_db
6
+ from .audio_file import AudioFile
7
+ from .transcript import Transcript
8
+ from .auth import User, ApiKey
9
+
10
+ __all__ = [
11
+ "Base",
12
+ "engine",
13
+ "SessionLocal",
14
+ "get_db",
15
+ "AudioFile",
16
+ "Transcript",
17
+ "User",
18
+ "ApiKey",
19
+ ]
backend/app/models/audio_file.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio File Model
3
+ """
4
+
5
+ from datetime import datetime
6
+ from sqlalchemy import Column, Integer, String, Float, DateTime, ForeignKey, Enum
7
+ from sqlalchemy.orm import relationship
8
+ import enum
9
+
10
+ from .base import Base
11
+
12
+
13
class AudioFileStatus(str, enum.Enum):
    """
    Audio file processing status.

    Inherits from str so members compare and serialize as their plain string
    values; the AudioFile.status column stores the .value, not the enum.
    """
    UPLOADED = "uploaded"      # stored on disk, not yet processed
    PROCESSING = "processing"  # transcription underway
    DONE = "done"              # processed successfully
    FAILED = "failed"          # processing aborted; see AudioFile.error_message
19
+
20
+
21
class AudioFile(Base):
    """Uploaded audio file and its processing metadata (one row per upload)."""

    __tablename__ = "audio_files"

    id = Column(Integer, primary_key=True, index=True)
    # user_id removed
    storage_path = Column(String(500), nullable=False)  # where the file lives on disk
    original_filename = Column(String(255), nullable=True)
    duration = Column(Float, nullable=True)  # Duration in seconds
    format = Column(String(20), nullable=True)  # wav, mp3, etc.
    sample_rate = Column(Integer, nullable=True)  # presumably Hz — confirm in upload pipeline
    channels = Column(Integer, nullable=True)  # channel count (1 = mono, 2 = stereo)
    file_size = Column(Integer, nullable=True)  # Size in bytes
    language = Column(String(10), nullable=True)  # User-specified language
    detected_language = Column(String(10), nullable=True)  # Auto-detected language
    status = Column(String(20), default=AudioFileStatus.UPLOADED.value, index=True)  # AudioFileStatus value
    error_message = Column(String(500), nullable=True)  # presumably set on failure — confirm in STT pipeline
    created_at = Column(DateTime, default=datetime.utcnow, index=True)
    processed_at = Column(DateTime, nullable=True)

    # Relationships
    # user relationship removed
    transcripts = relationship("Transcript", back_populates="audio_file")

    def __repr__(self):
        return f"<AudioFile(id={self.id}, filename={self.original_filename}, status={self.status})>"
backend/app/models/auth.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ User and API Key Models
3
+ """
4
+
5
+ from sqlalchemy import Column, Integer, String, Boolean, ForeignKey, DateTime
6
+ from sqlalchemy.orm import relationship
7
+ from datetime import datetime
8
+ from .base import Base
9
+ from ..core.security_encryption import EncryptedString
10
+
11
class User(Base):
    """Account record; full_name is encrypted at rest via EncryptedString."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True, index=True)
    email = Column(String, unique=True, index=True, nullable=False)  # Not encrypted (needed for lookup)
    hashed_password = Column(String, nullable=False)  # password hash only, never plaintext
    full_name = Column(EncryptedString(255), nullable=True)  # ENCRYPTED
    is_active = Column(Boolean, default=True)  # inactive accounts are rejected by get_current_active_user
    is_superuser = Column(Boolean, default=False)

    # Relationships
    api_keys = relationship("ApiKey", back_populates="user", cascade="all, delete-orphan")
23
+
24
+
25
class ApiKey(Base):
    """Long-lived API credential tied to a user (sent via the X-API-Key header)."""

    __tablename__ = "api_keys"

    id = Column(Integer, primary_key=True, index=True)
    # NOTE(review): the key is stored in plaintext; hashing it (as passwords
    # are) would limit damage from a database leak.
    key = Column(String, unique=True, index=True, nullable=False)
    name = Column(String, nullable=True)  # e.g. "Production App"
    is_active = Column(Boolean, default=True)  # only active keys authenticate (see verify_api_key)
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used_at = Column(DateTime, nullable=True)  # updated on each successful lookup

    user_id = Column(Integer, ForeignKey("users.id"))
    user = relationship("User", back_populates="api_keys")
backend/app/models/base.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SQLAlchemy Base and Database Session
3
+ """
4
+
5
+ from sqlalchemy import create_engine
6
+ from sqlalchemy.ext.declarative import declarative_base
7
+ from sqlalchemy.orm import sessionmaker
8
+
9
+ from ..core.config import get_settings
10
+
11
settings = get_settings()

# Create SQLAlchemy engine.
# SQLite needs check_same_thread=False so the connection can be shared
# across FastAPI's threadpool workers (safe with session-per-request).
if "sqlite" in settings.database_url:
    engine = create_engine(
        settings.database_url,
        connect_args={"check_same_thread": False},
    )
else:
    # Server databases: pre-ping drops stale pooled connections, and the
    # pool is sized for moderate concurrency (10 + 20 overflow).
    engine = create_engine(
        settings.database_url,
        pool_pre_ping=True,
        pool_size=10,
        max_overflow=20,
    )

# Create session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Create declarative base
Base = declarative_base()
33
+
34
+
35
def get_db():
    """FastAPI dependency that yields a database session.

    Opens one session per request and guarantees it is closed afterwards,
    whether the request handler succeeds or raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
backend/app/models/sign_lstm.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+
5
class SignLSTM(nn.Module):
    """LSTM classifier for sign-language gesture sequences.

    Input is a sequence of MediaPipe hand landmarks (21 points * 3 coords
    = 63 features per frame).  The stacked LSTM layers capture the temporal
    dynamics of a sign — its motion and context, not just a static hand
    shape — and a linear head maps the final hidden state to class logits
    (ASL alphabet or vocabulary).
    """

    def __init__(self, input_size=63, hidden_size=128, num_layers=2, num_classes=26):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # batch_first=True -> inputs are (batch, seq_len, features).
        # dropout applies between stacked LSTM layers during training.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=0.2,
        )

        # Classification head over the last time step's hidden state
        self.fc = nn.Linear(hidden_size, num_classes)

        # Used only at inference time to turn logits into probabilities
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return class logits.

        Args:
            x: tensor of shape (batch_size, sequence_length, input_size).
        """
        # Hidden and cell states default to zeros when not supplied.
        sequence_out, _ = self.lstm(x)
        # Keep only the representation of the final time step.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)

    def predict(self, x):
        """Inference helper: return (predicted_class, confidence) for one sample."""
        with torch.no_grad():
            probabilities = self.softmax(self.forward(x))
            confidence, predicted_class = torch.max(probabilities, 1)
            return predicted_class.item(), confidence.item()
backend/app/models/transcript.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transcript Model
3
+ """
4
+
5
+ from datetime import datetime
6
+ from sqlalchemy import Column, Integer, String, Text, DateTime, ForeignKey, JSON, Float
7
+ from sqlalchemy.orm import relationship
8
+
9
+ from .base import Base
10
+ from ..core.security_encryption import EncryptedString
11
+
12
+
13
class Transcript(Base):
    """Transcript database model.

    Stores the raw/processed transcript text (encrypted at rest), the
    segment- and word-level timing data as JSON, optional NLP analysis
    results, and bookkeeping metadata for one transcription run.
    """

    __tablename__ = "transcripts"

    id = Column(Integer, primary_key=True, index=True)
    # Nullable: ad-hoc transcriptions may not reference a stored audio file.
    # (Fix: this column was accidentally declared twice.)
    audio_file_id = Column(Integer, ForeignKey("audio_files.id"), nullable=True, index=True)
    # user_id removed (Auth disabled for portfolio)

    # Transcript content - ENCRYPTED
    raw_text = Column(EncryptedString(10000), nullable=True)  # Original transcription
    processed_text = Column(EncryptedString(10000), nullable=True)  # After NLP processing

    # Segments with timestamps and speaker info (JSON array)
    # Format: [{"start": 0.0, "end": 1.5, "text": "Hello", "speaker": "SPEAKER_1", "confidence": 0.95}]
    segments = Column(JSON, nullable=True)

    # Word-level timestamps (JSON array)
    # Format: [{"word": "hello", "start": 0.0, "end": 0.5, "confidence": 0.98}]
    words = Column(JSON, nullable=True)

    # Language info
    language = Column(String(10), nullable=True)  # Transcription language
    translation_language = Column(String(10), nullable=True)  # If translated
    translated_text = Column(Text, nullable=True)

    # NLP Analysis (Phase 2)
    sentiment = Column(JSON, nullable=True)  # {"overall": "positive", "score": 0.8, "segments": [...]}
    topics = Column(JSON, nullable=True)  # ["technology", "business"]
    keywords = Column(JSON, nullable=True)  # [{"word": "AI", "score": 0.9}]
    action_items = Column(JSON, nullable=True)  # [{"text": "Email John", "assignee": "Speaker 1"}]
    attendees = Column(JSON, nullable=True)  # ["Speaker 1", "Speaker 2"]
    summary = Column(EncryptedString(5000), nullable=True)  # ENCRYPTED

    # Metadata
    confidence = Column(Float, nullable=True)  # Overall confidence score
    duration = Column(Float, nullable=True)  # Audio duration in seconds
    word_count = Column(Integer, nullable=True)

    # Timestamps
    created_at = Column(DateTime, default=datetime.utcnow, index=True)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)

    # Relationships (single definition; duplicate removed)
    audio_file = relationship("AudioFile", back_populates="transcripts")
    # user relationship removed

    def __repr__(self):
        preview = self.raw_text[:50] + "..." if self.raw_text and len(self.raw_text) > 50 else self.raw_text
        return f"<Transcript(id={self.id}, preview='{preview}')>"
backend/app/schemas/__init__.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceForge Schemas Package
3
+ """
4
+
5
+ from .stt import (
6
+ TranscriptionRequest,
7
+ TranscriptionResponse,
8
+ TranscriptionSegment,
9
+ TranscriptionWord,
10
+ LanguageInfo,
11
+ )
12
+ from .tts import (
13
+ SynthesisRequest,
14
+ SynthesisResponse,
15
+ VoiceInfo,
16
+ VoiceListResponse,
17
+ )
18
+ from .transcript import (
19
+ TranscriptCreate,
20
+ TranscriptUpdate,
21
+ TranscriptResponse,
22
+ TranscriptListResponse,
23
+ )
24
+
25
+ __all__ = [
26
+ "TranscriptionRequest",
27
+ "TranscriptionResponse",
28
+ "TranscriptionSegment",
29
+ "TranscriptionWord",
30
+ "LanguageInfo",
31
+ "SynthesisRequest",
32
+ "SynthesisResponse",
33
+ "VoiceInfo",
34
+ "VoiceListResponse",
35
+ "TranscriptCreate",
36
+ "TranscriptUpdate",
37
+ "TranscriptResponse",
38
+ "TranscriptListResponse",
39
+ ]
backend/app/schemas/stt.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speech-to-Text Schemas
3
+ """
4
+
5
+ from datetime import datetime
6
+ from typing import List, Optional, Dict, Any
7
+ from pydantic import BaseModel, Field
8
+
9
+
10
class TranscriptionWord(BaseModel):
    """Individual word with timing information (offsets from audio start)."""
    word: str
    start_time: float = Field(..., description="Start time in seconds")
    end_time: float = Field(..., description="End time in seconds")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
16
+
17
+
18
class TranscriptionSegment(BaseModel):
    """Transcript segment with speaker and timing.

    Optionally carries the word-level breakdown for this segment.
    """
    text: str
    start_time: float = Field(..., description="Start time in seconds")
    end_time: float = Field(..., description="End time in seconds")
    speaker: Optional[str] = Field(None, description="Speaker label (e.g., SPEAKER_1)")
    confidence: float = Field(..., ge=0.0, le=1.0)
    words: Optional[List[TranscriptionWord]] = None  # present when word timestamps were requested
26
+
27
+
28
class TranscriptionRequest(BaseModel):
    """Request parameters for transcription."""
    language: str = Field(default="en-US", description="Language code (e.g., en-US)")
    enable_automatic_punctuation: bool = True
    enable_word_time_offsets: bool = True  # word-level timestamps in the response
    enable_speaker_diarization: bool = False
    # Only meaningful when diarization is enabled; bounded to 2-10 speakers
    diarization_speaker_count: Optional[int] = Field(None, ge=2, le=10)
    model: str = Field(default="default", description="STT model to use")
36
+
37
+
38
class TranscriptionResponse(BaseModel):
    """Response from transcription."""
    id: Optional[int] = None  # DB row id, when the transcript was persisted
    audio_file_id: Optional[int] = None
    text: str = Field(..., description="Full transcription text")
    segments: List[TranscriptionSegment] = Field(default_factory=list)
    words: Optional[List[TranscriptionWord]] = None
    language: str  # language that was requested
    detected_language: Optional[str] = None  # auto-detected language, if any
    confidence: float = Field(..., ge=0.0, le=1.0)
    duration: float = Field(..., description="Audio duration in seconds")
    word_count: int
    processing_time: float = Field(..., description="Processing time in seconds")

    # Pydantic v2: allow construction directly from ORM objects
    model_config = {
        "from_attributes": True
    }
55
+
56
+
57
class StreamingTranscriptionResponse(BaseModel):
    """Response for streaming transcription updates.

    Interim hypotheses arrive with is_final=False; the final result for a
    stretch of audio has is_final=True.
    """
    is_final: bool = False
    text: str
    confidence: float = Field(default=0.0, ge=0.0, le=1.0)
    stability: float = Field(default=0.0, ge=0.0, le=1.0)  # how unlikely the interim text is to change
63
+
64
+
65
class LanguageInfo(BaseModel):
    """Language information for UI display."""
    code: str = Field(..., description="Language code (e.g., en-US)")
    name: str = Field(..., description="Display name (e.g., English (US))")
    native_name: str = Field(..., description="Native name (e.g., English)")
    flag: str = Field(..., description="Flag emoji")
    stt_supported: bool = True  # speech-to-text available for this language
    tts_supported: bool = True  # text-to-speech available for this language
73
+
74
+
75
class LanguageListResponse(BaseModel):
    """Response with list of supported languages."""
    languages: List[LanguageInfo]
    total: int  # number of entries in `languages`
79
+
80
+
81
+
82
class TaskStatusResponse(BaseModel):
    """Status of an async transcription task.

    `result` is populated only once `status` is "completed"; `error` only
    when it is "failed".
    """
    task_id: str
    status: str = Field(..., description="pending, processing, completed, failed")
    progress: float = Field(default=0.0, ge=0.0, le=100.0, description="Progress percentage")
    result: Optional[TranscriptionResponse] = None
    error: Optional[str] = None
    created_at: datetime
    updated_at: datetime
91
+
92
+
93
class AsyncTranscriptionResponse(BaseModel):
    """Response for async transcription submission (202-style acknowledgement)."""
    task_id: str  # poll this id via the task-status endpoint
    audio_file_id: int
    status: str = "queued"
    message: str = "File uploaded and queued for processing"
backend/app/schemas/transcript.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transcript Schemas
3
+ """
4
+
5
+ from datetime import datetime
6
+ from typing import List, Optional, Dict, Any
7
+ from pydantic import BaseModel, Field
8
+
9
+ from .stt import TranscriptionSegment, TranscriptionWord
10
+
11
+
12
class TranscriptCreate(BaseModel):
    """Schema for creating a transcript."""
    raw_text: str
    processed_text: Optional[str] = None  # post-NLP text, if already computed
    segments: Optional[List[Dict[str, Any]]] = None  # raw segment dicts (start/end/text/...)
    words: Optional[List[Dict[str, Any]]] = None  # raw word-timestamp dicts
    language: str = "en-US"
    confidence: Optional[float] = None
    duration: Optional[float] = None  # audio duration in seconds
21
+
22
+
23
class TranscriptUpdate(BaseModel):
    """Schema for updating a transcript (all fields optional / partial update)."""
    processed_text: Optional[str] = None
    language: Optional[str] = None
27
+
28
+
29
class TranscriptResponse(BaseModel):
    """Schema for transcript response (mirrors the Transcript ORM model)."""
    id: int
    audio_file_id: Optional[int] = None
    user_id: Optional[int] = None  # kept for API compatibility; auth is disabled
    raw_text: Optional[str] = None
    processed_text: Optional[str] = None
    segments: Optional[List[Dict[str, Any]]] = None
    words: Optional[List[Dict[str, Any]]] = None
    language: Optional[str] = None
    translation_language: Optional[str] = None
    translated_text: Optional[str] = None
    sentiment: Optional[Dict[str, Any]] = None
    topics: Optional[List[str]] = None
    keywords: Optional[List[Dict[str, Any]]] = None
    summary: Optional[str] = None
    confidence: Optional[float] = None
    duration: Optional[float] = None  # audio duration in seconds
    word_count: Optional[int] = None
    created_at: datetime
    updated_at: Optional[datetime] = None

    # Pydantic v2: allow construction directly from ORM objects
    model_config = {
        "from_attributes": True
    }
54
+
55
+
56
class TranscriptListResponse(BaseModel):
    """Schema for paginated transcript list."""
    transcripts: List[TranscriptResponse]
    total: int  # total matching rows, not just this page
    page: int
    page_size: int
    has_more: bool  # True when further pages exist
63
+
64
+
65
class ExportRequest(BaseModel):
    """Schema for transcript export request."""
    # Only these export formats are accepted (validated via regex)
    format: str = Field(..., pattern="^(txt|srt|vtt|pdf|json)$")
    include_timestamps: bool = True
    include_speakers: bool = True
backend/app/schemas/tts.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text-to-Speech Schemas
3
+ """
4
+
5
+ from typing import List, Optional
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
class SynthesisRequest(BaseModel):
    """Request for text-to-speech synthesis."""
    text: str = Field(..., min_length=1, max_length=5000, description="Text to synthesize")
    language: str = Field(default="en-US", description="Language code")
    voice: Optional[str] = Field(None, description="Voice name (e.g., en-US-Wavenet-D)")

    # Audio configuration
    audio_encoding: str = Field(default="MP3", description="Output format: MP3, LINEAR16, OGG_OPUS")
    sample_rate: int = Field(default=24000, description="Sample rate in Hz")

    # Voice tuning
    speaking_rate: float = Field(default=1.0, ge=0.25, le=4.0, description="Speaking rate")
    pitch: float = Field(default=0.0, ge=-20.0, le=20.0, description="Voice pitch in semitones")
    volume_gain_db: float = Field(default=0.0, ge=-96.0, le=16.0, description="Volume gain in dB")

    # SSML support
    use_ssml: bool = Field(default=False, description="Treat text as SSML")
26
+
27
+
28
class SynthesisResponse(BaseModel):
    """Response from text-to-speech synthesis."""
    audio_content: str = Field(..., description="Base64 encoded audio")
    audio_size: int = Field(..., description="Audio size in bytes")
    duration_estimate: float = Field(..., description="Estimated duration in seconds")
    voice_used: str  # actual voice selected (may differ from the requested one)
    language: str
    encoding: str
    sample_rate: int
    processing_time: float = Field(..., description="Processing time in seconds")
38
+
39
+
40
class VoiceInfo(BaseModel):
    """Information about a TTS voice."""
    name: str = Field(..., description="Voice name (e.g., en-US-Wavenet-D)")
    language_code: str = Field(..., description="Language code")
    language_name: str = Field(..., description="Language display name")
    ssml_gender: str = Field(..., description="MALE, FEMALE, or NEUTRAL")
    natural_sample_rate: int = Field(..., description="Native sample rate in Hz")
    voice_type: str = Field(..., description="Standard, WaveNet, or Neural2")

    # Display helpers (optional; filled in for UI rendering)
    display_name: Optional[str] = None
    flag: Optional[str] = None
52
+
53
+
54
class VoiceListResponse(BaseModel):
    """Response with list of available voices."""
    voices: List[VoiceInfo]
    total: int  # number of entries in `voices`
    language_filter: Optional[str] = None  # language code used to filter, if any
59
+
60
+
61
class VoicePreviewRequest(BaseModel):
    """Request for voice preview (short sample clip of a given voice)."""
    voice: str = Field(..., description="Voice name to preview")
    text: Optional[str] = Field(
        default="Hello! This is a preview of my voice.",
        max_length=200
    )
backend/app/services/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceForge Services Package
3
+ """
4
+
5
+ from .stt_service import STTService
6
+ from .tts_service import TTSService
7
+ from .file_service import FileService
8
+
9
+ __all__ = [
10
+ "STTService",
11
+ "TTSService",
12
+ "FileService",
13
+ ]
backend/app/services/audio_service.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Audio Editing Service
3
+ Handles audio manipulation: Trimming, Merging, and Conversion using Pydub/FFmpeg
4
+ """
5
+
6
+ import os
7
+ import logging
8
+ from typing import List, Optional
9
+ from pydub import AudioSegment
10
+ import tempfile
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
class AudioService:
    """
    Service for audio manipulation tasks: trimming, merging, conversion.
    Requires ffmpeg to be installed/available in path (used by pydub).
    """

    def __init__(self):
        pass

    def load_audio(self, file_path: str) -> AudioSegment:
        """Load audio file into Pydub AudioSegment.

        Raises:
            ValueError: if the file is missing or cannot be decoded.
        """
        try:
            return AudioSegment.from_file(file_path)
        except Exception as e:
            logger.error(f"Failed to load audio {file_path}: {e}")
            raise ValueError(f"Could not load audio file: {str(e)}")

    def trim_audio(self, input_path: str, start_ms: int, end_ms: int, output_path: Optional[str] = None) -> str:
        """
        Trim audio from start_ms to end_ms (milliseconds).

        Returns the path of the exported clip. When output_path is omitted,
        writes "<input>_trimmed.<ext>" next to the input.
        """
        if start_ms < 0 or end_ms <= start_ms:
            raise ValueError("Invalid start/end timestamps")

        audio = self.load_audio(input_path)

        # Check duration (pydub clamps an end_ms past the audio end)
        if start_ms >= len(audio):
            raise ValueError("Start time exceeds audio duration")

        # Slice
        trimmed = audio[start_ms:end_ms]

        if not output_path:
            base, ext = os.path.splitext(input_path)
            output_path = f"{base}_trimmed{ext}"

        # Fix: fall back to mp3 when the output path has no extension;
        # previously export() received an empty format string and failed.
        fmt = os.path.splitext(output_path)[1][1:] or "mp3"
        trimmed.export(output_path, format=fmt)
        logger.info(f"Trimmed audio saved to {output_path}")
        return output_path

    def merge_audio(self, file_paths: List[str], output_path: str, crossfade_ms: int = 0) -> str:
        """
        Merge multiple audio files into one, optionally crossfading
        between consecutive clips.
        """
        if not file_paths:
            raise ValueError("No files to merge")

        combined = AudioSegment.empty()

        for path in file_paths:
            segment = self.load_audio(path)
            if crossfade_ms > 0 and len(combined) > 0:
                combined = combined.append(segment, crossfade=crossfade_ms)
            else:
                combined += segment

        # Create dir if needed. Guard against a bare filename: dirname("")
        # would make os.makedirs raise.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Export (default to mp3 when the path has no extension)
        fmt = os.path.splitext(output_path)[1][1:] or "mp3"
        combined.export(output_path, format=fmt)
        logger.info(f"Merged {len(file_paths)} files to {output_path}")
        return output_path

    def convert_format(self, input_path: str, target_format: str) -> str:
        """
        Convert audio format (e.g. wav -> mp3).

        Writes "<input base>.<target_format>" and returns its path.
        """
        audio = self.load_audio(input_path)

        base = os.path.splitext(input_path)[0]
        output_path = f"{base}.{target_format}"

        audio.export(output_path, format=target_format)
        logger.info(f"Converted to {target_format}: {output_path}")
        return output_path
92
+
93
+
94
# Lazily-created module singleton
_audio_service = None


def get_audio_service() -> AudioService:
    """Return the shared AudioService, constructing it on first use."""
    global _audio_service
    if _audio_service is not None:
        return _audio_service
    _audio_service = AudioService()
    return _audio_service
backend/app/services/batch_service.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Batch Processing Service
3
+ Handles multi-file transcription with job tracking and parallel processing
4
+ """
5
+
6
+ import asyncio
7
+ import logging
8
+ import os
9
+ import tempfile
10
+ import uuid
11
+ import zipfile
12
+ from datetime import datetime
13
+ from pathlib import Path
14
+ from typing import Dict, List, Optional, Any
15
+ from dataclasses import dataclass, field
16
+ from enum import Enum
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
class JobStatus(str, Enum):
    """Batch job status enum (str-valued so it serializes directly to JSON)."""
    PENDING = "pending"        # created, not yet started
    PROCESSING = "processing"  # worker loop is running
    COMPLETED = "completed"    # finished with zero failed files
    FAILED = "failed"          # finished with at least one failed file
    CANCELLED = "cancelled"    # cancelled before completion
28
+
29
+
30
class FileStatus(str, Enum):
    """Individual file status within a batch job."""
    QUEUED = "queued"          # waiting its turn in the job
    PROCESSING = "processing"  # currently being transcribed
    COMPLETED = "completed"
    FAILED = "failed"
36
+
37
+
38
@dataclass
class FileResult:
    """Result for a single file in batch."""
    filename: str
    status: FileStatus = FileStatus.QUEUED
    progress: float = 0.0  # 0-100 for this file
    transcript: Optional[str] = None  # full transcript text once completed
    language: Optional[str] = None  # detected/used language code
    duration: Optional[float] = None  # audio duration in seconds
    word_count: Optional[int] = None
    processing_time: Optional[float] = None  # wall-clock seconds for this file
    error: Optional[str] = None  # set only when status == FAILED
    output_path: Optional[str] = None  # path of the written txt/srt output
51
+
52
+
53
@dataclass
class BatchJob:
    """Batch processing job: aggregate status plus per-file results."""
    job_id: str
    status: JobStatus = JobStatus.PENDING
    created_at: datetime = field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    files: Dict[str, FileResult] = field(default_factory=dict)  # keyed by filename
    total_files: int = 0
    completed_files: int = 0
    failed_files: int = 0
    options: Dict[str, Any] = field(default_factory=dict)  # e.g. language, output_format
    output_zip_path: Optional[str] = None  # set once all outputs are zipped

    @property
    def progress(self) -> float:
        """Overall job progress percentage (finished files / total, 0-100)."""
        if self.total_files == 0:
            # guard against division by zero for empty jobs
            return 0.0
        return (self.completed_files + self.failed_files) / self.total_files * 100

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for API response.

        Transcripts are truncated to 500 chars to keep the payload small;
        clients download full results via the ZIP endpoint.
        """
        return {
            "job_id": self.job_id,
            "status": self.status.value,
            "progress": round(self.progress, 1),
            "created_at": self.created_at.isoformat(),
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "total_files": self.total_files,
            "completed_files": self.completed_files,
            "failed_files": self.failed_files,
            "files": {
                name: {
                    "filename": f.filename,
                    "status": f.status.value,
                    "progress": f.progress,
                    # truncate long transcripts for the status payload
                    "transcript": f.transcript[:500] + "..." if f.transcript and len(f.transcript) > 500 else f.transcript,
                    "language": f.language,
                    "duration": f.duration,
                    "word_count": f.word_count,
                    "processing_time": f.processing_time,
                    "error": f.error,
                }
                for name, f in self.files.items()
            },
            "options": self.options,
            "has_zip": self.output_zip_path is not None,
        }
104
+
105
+
106
+ # In-memory job store (use Redis in production)
107
+ _batch_jobs: Dict[str, BatchJob] = {}
108
+
109
+
110
class BatchProcessingService:
    """
    Service for batch audio transcription.

    Processes multiple files with per-file progress tracking, writes one
    output file per input (txt or srt), and bundles all outputs into a ZIP.
    Job state lives in the module-level in-memory store (use Redis in
    production).
    """

    def __init__(self, output_dir: Optional[str] = None):
        """Initialize batch service.

        Args:
            output_dir: Directory for per-job outputs and result ZIPs.
                Defaults to the system temp directory.
        """
        self.output_dir = output_dir or tempfile.gettempdir()
        self._processing_lock = asyncio.Lock()

    def create_job(
        self,
        filenames: List[str],
        options: Optional[Dict[str, Any]] = None,
    ) -> BatchJob:
        """
        Create a new batch job.

        Args:
            filenames: List of filenames to process
            options: Processing options (language, output_format, etc.)

        Returns:
            Created BatchJob
        """
        # A short id is sufficient for this in-memory, single-process store
        job_id = str(uuid.uuid4())[:8]

        files = {
            name: FileResult(filename=name)
            for name in filenames
        }

        job = BatchJob(
            job_id=job_id,
            files=files,
            total_files=len(filenames),
            options=options or {},
        )

        _batch_jobs[job_id] = job
        logger.info(f"Created batch job {job_id} with {len(filenames)} files")

        return job

    def get_job(self, job_id: str) -> Optional[BatchJob]:
        """Get job by ID, or None if unknown."""
        return _batch_jobs.get(job_id)

    def list_jobs(self, limit: int = 20) -> List[BatchJob]:
        """List the most recent jobs, newest first."""
        jobs = list(_batch_jobs.values())
        jobs.sort(key=lambda j: j.created_at, reverse=True)
        return jobs[:limit]

    async def process_job(
        self,
        job_id: str,
        file_paths: Dict[str, str],
    ) -> BatchJob:
        """
        Process all files in a batch job.

        Args:
            job_id: Job ID
            file_paths: Mapping of filename -> temp file path

        Returns:
            Completed BatchJob

        Raises:
            ValueError: if the job ID is unknown.
        """
        import time  # hoisted out of the per-file loop

        job = self.get_job(job_id)
        if not job:
            raise ValueError(f"Job not found: {job_id}")

        job.status = JobStatus.PROCESSING
        job.started_at = datetime.now()

        # Get options
        language = job.options.get("language")
        output_format = job.options.get("output_format", "txt")

        # Process each file sequentially; the heavy compute runs in Celery workers
        output_files: List[str] = []

        for filename, file_path in file_paths.items():
            file_result = job.files.get(filename)
            if not file_result:
                continue

            file_result.status = FileStatus.PROCESSING
            file_result.progress = 0.0

            try:
                start_time = time.time()

                # Transcribe via Celery worker (local import: celery app is
                # only needed here, and importing lazily avoids cycles)
                from app.workers.tasks import transcribe_file_path

                # Dispatch task
                task = transcribe_file_path.delay(
                    file_path=file_path,
                    language=language,
                    output_format=output_format
                )

                # Wait for result (this service already runs in a background
                # thread). In a fully async architecture we would return the
                # job id and poll, but a synchronous wait keeps the batch
                # logic simple while the compute still scales out.
                task_result = task.get(timeout=600)  # 10 min timeout per file

                processing_time = time.time() - start_time

                # Update file result
                file_result.transcript = task_result.get("text", "")
                file_result.language = task_result.get("language", "unknown")
                file_result.duration = task_result.get("duration")
                file_result.word_count = len(file_result.transcript.split())
                file_result.processing_time = round(processing_time, 2)
                file_result.status = FileStatus.COMPLETED
                file_result.progress = 100.0

                # Save output file under a per-job directory
                output_filename = Path(filename).stem + f".{output_format}"
                output_path = os.path.join(self.output_dir, job_id, output_filename)
                os.makedirs(os.path.dirname(output_path), exist_ok=True)

                with open(output_path, "w", encoding="utf-8") as f:
                    if output_format == "srt":
                        # Write SRT format from the raw segment dicts
                        for i, seg in enumerate(task_result.get("segments", []), 1):
                            start = self._format_srt_time(seg.get("start", 0))
                            end = self._format_srt_time(seg.get("end", 0))
                            text = seg.get("text", "").strip()
                            f.write(f"{i}\n{start} --> {end}\n{text}\n\n")
                    else:
                        f.write(file_result.transcript)

                file_result.output_path = output_path
                output_files.append(output_path)

                job.completed_files += 1
                # Fix: log the actual filename (was hard-coded "(unknown)")
                logger.info(f"[{job_id}] Completed {filename} ({job.completed_files}/{job.total_files})")

            except Exception as e:
                file_result.status = FileStatus.FAILED
                file_result.error = str(e)
                file_result.progress = 0.0
                job.failed_files += 1
                logger.error(f"[{job_id}] Failed {filename}: {e}")

            finally:
                # Best-effort temp file cleanup; never fail the job over it
                try:
                    if os.path.exists(file_path):
                        os.unlink(file_path)
                except OSError:
                    pass

        # Create ZIP of all outputs
        if output_files:
            zip_path = os.path.join(self.output_dir, f"{job_id}_results.zip")
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
                for file_path in output_files:
                    zf.write(file_path, os.path.basename(file_path))

            job.output_zip_path = zip_path
            logger.info(f"[{job_id}] Created ZIP: {zip_path}")

        # Update job status: FAILED if any file failed, else COMPLETED
        job.status = JobStatus.COMPLETED if job.failed_files == 0 else JobStatus.FAILED
        job.completed_at = datetime.now()

        return job

    def _format_srt_time(self, seconds: float) -> str:
        """Format seconds to SRT time format (HH:MM:SS,mmm)."""
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        millis = int((seconds % 1) * 1000)
        return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"

    def cancel_job(self, job_id: str) -> bool:
        """Cancel a pending/processing job. Returns True if the state changed."""
        job = self.get_job(job_id)
        if job and job.status in [JobStatus.PENDING, JobStatus.PROCESSING]:
            job.status = JobStatus.CANCELLED
            return True
        return False

    def delete_job(self, job_id: str) -> bool:
        """Delete a job and its output files. Returns True if the job existed."""
        job = _batch_jobs.pop(job_id, None)
        if job:
            # Clean up files (best effort; ignore filesystem races)
            if job.output_zip_path and os.path.exists(job.output_zip_path):
                try:
                    os.unlink(job.output_zip_path)
                except OSError:
                    pass

            job_dir = os.path.join(self.output_dir, job_id)
            if os.path.exists(job_dir):
                try:
                    import shutil
                    shutil.rmtree(job_dir)
                except OSError:
                    pass

            return True
        return False

    def get_zip_path(self, job_id: str) -> Optional[str]:
        """Get path to job's output ZIP file, or None if not available."""
        job = self.get_job(job_id)
        if job and job.output_zip_path and os.path.exists(job.output_zip_path):
            return job.output_zip_path
        return None
337
+
338
+
339
# Lazily-created module singleton
_batch_service: Optional[BatchProcessingService] = None


def get_batch_service() -> BatchProcessingService:
    """Return the shared BatchProcessingService, constructing it on first use."""
    global _batch_service
    if _batch_service is not None:
        return _batch_service
    _batch_service = BatchProcessingService()
    return _batch_service
backend/app/services/cache_service.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import redis
2
+ import json
3
+ import hashlib
4
+ import logging
5
+ from typing import Optional, Any
6
+ from functools import lru_cache
7
+
8
+ from ..core.config import get_settings
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class CacheService:
    """Two-tier byte cache: Redis when reachable, DiskCache as fallback.

    Values are raw bytes; callers handle serialization. Every backend
    failure is logged and degrades to a cache miss instead of raising.
    """

    def __init__(self):
        settings = get_settings()
        self.default_ttl = 3600  # seconds (1 hour)
        self.redis = None
        self.disk_cache = None

        # Try Redis first; ping() makes an unreachable server fail fast here
        # instead of on the first get/set.
        try:
            self.redis = redis.from_url(settings.redis_url, decode_responses=False)
            self.redis.ping()
            logger.info("✅ Redis Cache connected")
        except Exception as e:
            logger.warning(f"⚠️ Redis unavailable, falling back to DiskCache: {e}")
            self.redis = None

        # DiskCache is always initialized so it is ready as a fallback.
        try:
            import diskcache
            cache_dir = "./cache_data"
            self.disk_cache = diskcache.Cache(cache_dir)
            logger.info(f"💾 DiskCache initialized at {cache_dir}")
        except Exception as e:
            logger.error(f"❌ DiskCache init failed: {e}")

    def get(self, key: str) -> Optional[bytes]:
        """Get raw bytes from cache; returns None on miss or backend error."""
        try:
            if self.redis:
                return self.redis.get(key)
            elif self.disk_cache:
                return self.disk_cache.get(key)
        except Exception as e:
            logger.error(f"Cache get failed: {e}")
        return None

    def set(self, key: str, value: bytes, ttl: Optional[int] = None):
        """Store raw bytes under `key`.

        Args:
            key: Cache key (see generate_key for a stable scheme).
            value: Raw bytes to store.
            ttl: Expiry in seconds; falls back to self.default_ttl when
                None or 0 (fixed annotation: was typed plain `int` with a
                None default).
        """
        try:
            ttl_val = ttl or self.default_ttl

            if self.redis:
                self.redis.setex(key, ttl_val, value)
            elif self.disk_cache:
                self.disk_cache.set(key, value, expire=ttl_val)
        except Exception as e:
            logger.error(f"Cache set failed: {e}")

    def generate_key(self, prefix: str, **kwargs) -> str:
        """Generate a stable cache key from keyword arguments.

        Values are stringified so mixed types hash consistently, and
        `sort_keys=True` makes the key independent of call-site argument
        order (the previous explicit pre-sort was redundant with it).
        MD5 is used for key derivation only, not for security.
        """
        safe_kwargs = {k: str(v) for k, v in kwargs.items()}
        key_str = json.dumps(safe_kwargs, sort_keys=True)
        hash_str = hashlib.md5(key_str.encode()).hexdigest()
        return f"{prefix}:{hash_str}"
68
+
69
@lru_cache(maxsize=None)
def get_cache_service() -> CacheService:
    """Return the shared CacheService instance (constructed on first call)."""
    return CacheService()
backend/app/services/clone_service.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Voice Cloning Service (Coqui XTTS)
3
+ High-quality multi-lingual text-to-speech with voice cloning capabilities.
4
+ """
5
+
6
+ import os
7
+ import logging
8
+ import torch
9
+ import gc
10
+ from typing import List, Optional, Dict, Any
11
+ from pathlib import Path
12
+ import tempfile
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
class CloneService:
    """
    Service for Voice Cloning using Coqui XTTS v2.

    The heavy model is lazy-loaded on first use and can be unloaded to
    reclaim RAM/VRAM between jobs.
    """

    def __init__(self):
        # Prefer GPU when available; XTTS works on CPU but is much slower.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts = None
        self.model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
        self.loaded = False

    def load_model(self):
        """Lazy load the heavy XTTS model (no-op when already loaded).

        Raises:
            ImportError: if the optional 'TTS' package is not installed
                (chained to the original import error).
        """
        if self.loaded:
            return

        try:
            logger.info(f"Loading XTTS model ({self.device})... This may take a while.")
            from TTS.api import TTS

            # Load model
            self.tts = TTS(self.model_name).to(self.device)
            self.loaded = True
            logger.info("✅ XTTS Model loaded successfully")

        except ImportError as e:
            logger.error("TTS library not installed. Please install 'TTS'.")
            # Chain the cause so the original failure is preserved.
            raise ImportError("Voice Cloning requires 'TTS' library.") from e
        except Exception as e:
            logger.error(f"Failed to load XTTS model: {e}")
            raise

    def unload_model(self):
        """Unload the model and free memory (safe to call repeatedly)."""
        if self.tts:
            del self.tts
            self.tts = None
            self.loaded = False
            gc.collect()
            if torch.cuda.is_available():
                # Only meaningful — and only safe to call — with CUDA present.
                torch.cuda.empty_cache()
            logger.info("🗑️ XTTS Model unloaded")

    def clone_voice(
        self,
        text: str,
        speaker_wav_paths: List[str],
        language: str = "en",
        output_path: Optional[str] = None
    ) -> str:
        """
        Synthesize speech in the style of the reference audio.

        Args:
            text: Text to speak.
            speaker_wav_paths: Reference WAV file(s); several samples
                usually improve cloning quality.
            language: XTTS language code (see get_supported_languages()).
            output_path: Destination WAV; a temp file is created when
                omitted. The caller owns (and must eventually delete) it.

        Returns:
            Path to the generated WAV file.
        """
        if not self.loaded:
            self.load_model()

        if not output_path:
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
                output_path = f.name

        try:
            # XTTS synthesis
            # Note: speaker_wav can be a list of files for better cloning
            self.tts.tts_to_file(
                text=text,
                speaker_wav=speaker_wav_paths,
                language=language,
                file_path=output_path,
                split_sentences=True
            )

            logger.info(f"Cloned speech generated: {output_path}")
            return output_path

        except Exception as e:
            logger.error(f"Cloning failed: {e}")
            raise

    def get_supported_languages(self) -> List[str]:
        """Return the language codes supported by XTTS v2."""
        return ["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "hu", "ko"]
96
+
97
# Lazily-created module singleton
_clone_service = None


def get_clone_service():
    """Return the shared CloneService instance, creating it on first use."""
    global _clone_service
    if _clone_service is None:
        _clone_service = CloneService()
    return _clone_service
backend/app/services/diarization_service.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Speaker Diarization Service - Clean Implementation
3
+ Uses faster-whisper + pyannote.audio directly (no whisperx)
4
+
5
+ This avoids the KeyError bugs in whisperx alignment while providing
6
+ the same functionality.
7
+ """
8
+
9
+ import os
10
+ import gc
11
+ import logging
12
+ import torch
13
+ from typing import Optional, Dict, Any, List
14
+ from dotenv import load_dotenv
15
+
16
+ from app.core.config import get_settings
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Load environment variables from .env file
21
+ load_dotenv()
22
+
23
+ # Workaround for PyTorch 2.6+ weights_only security restriction
24
+ os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
25
+
26
+
27
class DiarizationService:
    """
    Speaker Diarization Service using faster-whisper + pyannote.audio.

    This implementation avoids whisperx entirely to prevent alignment bugs.

    Flow:
    1. Transcribe with faster-whisper (word-level timestamps)
    2. Diarize with pyannote.audio (speaker segments)
    3. Merge speakers with transcript segments

    Requires:
    - faster-whisper (already installed)
    - pyannote.audio
    - Valid Hugging Face Token (HF_TOKEN) in .env
    """

    def __init__(self):
        self.settings = get_settings()

        # Auto-detect GPU (prefer CUDA for speed)
        if torch.cuda.is_available():
            self.device = "cuda"
            self.compute_type = "float16"
            logger.info(f"🚀 Diarization using GPU: {torch.cuda.get_device_name(0)}")
        else:
            self.device = "cpu"
            self.compute_type = "int8"
            logger.info("⚠️ Diarization using CPU (slower)")

        # Load HF token (pyannote's diarization models are gated)
        self.hf_token = os.getenv("HF_TOKEN")
        if not self.hf_token:
            logger.warning("⚠️ HF_TOKEN not found. Speaker diarization will fail.")

        # FFmpeg Setup for Windows
        self._setup_ffmpeg()

    def _setup_ffmpeg(self):
        """Auto-configure FFmpeg from imageio-ffmpeg if not in PATH (best effort)."""
        try:
            import imageio_ffmpeg
            import shutil

            ffmpeg_src = imageio_ffmpeg.get_ffmpeg_exe()
            backend_dir = os.getcwd()
            ffmpeg_dest = os.path.join(backend_dir, "ffmpeg.exe")

            if not os.path.exists(ffmpeg_dest):
                shutil.copy(ffmpeg_src, ffmpeg_dest)
                logger.info(f"🔧 Configured FFmpeg: {ffmpeg_dest}")

            if backend_dir not in os.environ.get("PATH", ""):
                os.environ["PATH"] = backend_dir + os.pathsep + os.environ.get("PATH", "")

        except Exception as e:
            logger.warning(f"⚠️ Could not auto-configure FFmpeg: {e}")

    def check_requirements(self):
        """Validate requirements before processing.

        Raises:
            ValueError: when no Hugging Face token is configured.
        """
        if not self.hf_token:
            raise ValueError(
                "HF_TOKEN is missing. Add HF_TOKEN=your_token to .env file. "
                "Get one at: https://huggingface.co/settings/tokens"
            )

    def _get_diarization_pipeline(self):
        """Load pyannote diarization pipeline with PyTorch 2.6+ fix."""
        from pyannote.audio import Pipeline

        # Monkey-patch torch.load for PyTorch 2.6+ compatibility.
        # NOTE(review): patching a global is not thread-safe; this assumes a
        # single diarization job runs at a time — confirm against callers.
        original_load = torch.load

        def safe_load(*args, **kwargs):
            kwargs.pop('weights_only', None)
            return original_load(*args, **kwargs, weights_only=False)

        torch.load = safe_load
        try:
            pipeline = Pipeline.from_pretrained(
                "pyannote/speaker-diarization-3.1",
                use_auth_token=self.hf_token
            )
            if self.device == "cuda":
                pipeline.to(torch.device("cuda"))
            return pipeline
        finally:
            # Always restore the original loader, even on failure.
            torch.load = original_load

    def _transcribe_with_timestamps(self, audio_path: str, language: Optional[str] = None) -> Dict:
        """Transcribe audio using faster-whisper with word timestamps."""
        from faster_whisper import WhisperModel

        # CTranslate2 (faster-whisper) doesn't support float16 on all GPUs,
        # so Whisper always runs int8 regardless of device; pyannote still
        # benefits from CUDA. (Replaces a former always-"int8" conditional.)
        model = WhisperModel(
            "small",
            device=self.device,
            compute_type="int8"
        )

        segments_raw, info = model.transcribe(
            audio_path,
            language=language,
            word_timestamps=True,
            vad_filter=True
        )

        segments = []
        for segment in segments_raw:
            segments.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text.strip(),
                "words": [
                    {"start": w.start, "end": w.end, "word": w.word}
                    for w in (segment.words or [])
                ]
            })

        # Cleanup
        del model
        gc.collect()

        return {
            "segments": segments,
            "language": info.language
        }

    def _preprocess_audio(self, audio_path: str) -> str:
        """
        Apply noise reduction to audio file.
        Returns path to cleaned audio file (a new temp WAV), or the original
        path unchanged when the optional dependencies are missing or the
        reduction fails.
        """
        try:
            import noisereduce as nr
            import librosa
            import soundfile as sf
            import tempfile

            logger.info("🔧 Preprocessing audio (noise reduction)...")

            # Load audio
            audio, sr = librosa.load(audio_path, sr=16000, mono=True)

            # Apply spectral gating noise reduction
            reduced_noise = nr.reduce_noise(
                y=audio,
                sr=sr,
                stationary=True,
                prop_decrease=0.75
            )

            # Save to temp file
            temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
            sf.write(temp_file.name, reduced_noise, sr)

            logger.info(f" → Noise reduction complete, saved to {temp_file.name}")
            return temp_file.name

        except ImportError as e:
            logger.warning(f"⚠️ Audio preprocessing unavailable (install noisereduce, librosa, soundfile): {e}")
            return audio_path
        except Exception as e:
            logger.warning(f"⚠️ Audio preprocessing failed: {e}")
            return audio_path

    def _merge_speakers(self, transcript: Dict, diarization) -> List[Dict]:
        """
        Merge speaker labels from diarization with transcript segments.

        Uses midpoint matching with nearest-speaker fallback to minimize
        UNKNOWN labels.
        """
        segments = transcript["segments"]
        result = []

        # Build list of speaker turns for efficient lookup
        speaker_turns = [
            (turn.start, turn.end, spk)
            for turn, _, spk in diarization.itertracks(yield_label=True)
        ]

        for seg in segments:
            mid_time = (seg["start"] + seg["end"]) / 2
            speaker = None

            # Step 1: Try exact midpoint match
            for start, end, spk in speaker_turns:
                if start <= mid_time <= end:
                    speaker = spk
                    break

            # Step 2: If no match, find nearest speaker (fallback)
            if speaker is None and speaker_turns:
                min_distance = float('inf')
                for start, end, spk in speaker_turns:
                    # Distance to nearest edge of speaker segment
                    if mid_time < start:
                        dist = start - mid_time
                    elif mid_time > end:
                        dist = mid_time - end
                    else:
                        dist = 0  # Should have been caught above

                    if dist < min_distance:
                        min_distance = dist
                        speaker = spk

            # Final fallback (only when diarization produced no turns at all)
            if speaker is None:
                speaker = "UNKNOWN"

            result.append({
                "start": seg["start"],
                "end": seg["end"],
                "text": seg["text"],
                "speaker": speaker
            })

        return result

    def process_audio(
        self,
        audio_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: Optional[int] = None,
        max_speakers: Optional[int] = None,
        language: Optional[str] = None,
        preprocess: bool = False,
    ) -> Dict[str, Any]:
        """
        Full diarization pipeline: [Preprocess] → Transcribe → Diarize → Merge

        Args:
            audio_path: Path to audio file
            num_speakers: Exact number of speakers (optional)
            min_speakers: Minimum speakers (optional)
            max_speakers: Maximum speakers (optional)
            language: Force language code (optional, auto-detected if None)
            preprocess: Apply noise reduction before processing (default: False)

        Returns:
            Dict with segments, speaker_stats, language, status
        """
        self.check_requirements()

        logger.info(f"🎤 Starting diarization on {self.device}...")

        # Optional preprocessing for noise reduction
        processed_path = audio_path
        if preprocess:
            processed_path = self._preprocess_audio(audio_path)

        try:
            # Step 1: Transcribe with faster-whisper
            logger.info("Step 1/3: Transcribing audio...")
            transcript = self._transcribe_with_timestamps(processed_path, language)
            detected_lang = transcript["language"]
            logger.info(f" → Language: {detected_lang}, Segments: {len(transcript['segments'])}")

            # Step 2: Diarize with pyannote
            logger.info("Step 2/3: Identifying speakers...")
            pipeline = self._get_diarization_pipeline()

            diarization = pipeline(
                processed_path,
                num_speakers=num_speakers,
                min_speakers=min_speakers,
                max_speakers=max_speakers
            )

            # Cleanup pipeline
            del pipeline
            gc.collect()

            # Step 3: Merge results
            logger.info("Step 3/3: Merging speakers with transcript...")
            segments = self._merge_speakers(transcript, diarization)

            # Calculate per-speaker speaking time
            speaker_stats = {}
            for seg in segments:
                spk = seg["speaker"]
                dur = seg["end"] - seg["start"]
                speaker_stats[spk] = speaker_stats.get(spk, 0) + dur

            logger.info(f"✅ Diarization complete: {len(segments)} segments, {len(speaker_stats)} speakers")

            return {
                "segments": segments,
                "speaker_stats": speaker_stats,
                "language": detected_lang,
                "status": "success"
            }

        except Exception:
            logger.exception("Diarization failed")
            raise
        finally:
            # Fix: delete the temporary denoised WAV created by
            # _preprocess_audio (previously leaked on every call).
            if processed_path != audio_path:
                try:
                    os.unlink(processed_path)
                except OSError:
                    pass
            gc.collect()
            if self.device == "cuda":
                torch.cuda.empty_cache()
329
+
330
+
331
# Lazily-created module singleton
_diarization_service = None


def get_diarization_service():
    """Return the shared DiarizationService, instantiating it on first use."""
    global _diarization_service
    if _diarization_service is None:
        _diarization_service = DiarizationService()
    return _diarization_service
backend/app/services/edge_tts_service.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Edge-TTS Text-to-Speech Service
3
+ Free, high-quality neural TTS using Microsoft Edge's speech synthesis
4
+ """
5
+
6
+ import asyncio
7
+ import io
8
+ import logging
9
+ import edge_tts
10
+ from typing import Optional, List, Dict, Any
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
# Static fallback voice catalog keyed by locale. Consumed by
# EdgeTTSService.get_voices() when the live edge-tts voice list cannot be
# fetched; entry names follow the edge-tts "ShortName" scheme.
VOICE_CATALOG = {
    "en-US": [
        {"name": "en-US-AriaNeural", "gender": "Female", "style": "professional"},
        {"name": "en-US-GuyNeural", "gender": "Male", "style": "casual"},
        {"name": "en-US-JennyNeural", "gender": "Female", "style": "friendly"},
        {"name": "en-US-ChristopherNeural", "gender": "Male", "style": "newscast"},
    ],
    "en-GB": [
        {"name": "en-GB-SoniaNeural", "gender": "Female", "style": "professional"},
        {"name": "en-GB-RyanNeural", "gender": "Male", "style": "casual"},
    ],
    "en-IN": [
        {"name": "en-IN-NeerjaNeural", "gender": "Female", "style": "professional"},
        {"name": "en-IN-PrabhatNeural", "gender": "Male", "style": "casual"},
    ],
    "hi-IN": [
        {"name": "hi-IN-SwaraNeural", "gender": "Female", "style": "professional"},
        {"name": "hi-IN-MadhurNeural", "gender": "Male", "style": "casual"},
    ],
    "es-ES": [
        {"name": "es-ES-ElviraNeural", "gender": "Female", "style": "professional"},
        {"name": "es-ES-AlvaroNeural", "gender": "Male", "style": "casual"},
    ],
    "es-MX": [
        {"name": "es-MX-DaliaNeural", "gender": "Female", "style": "professional"},
        {"name": "es-MX-JorgeNeural", "gender": "Male", "style": "casual"},
    ],
    "fr-FR": [
        {"name": "fr-FR-DeniseNeural", "gender": "Female", "style": "professional"},
        {"name": "fr-FR-HenriNeural", "gender": "Male", "style": "casual"},
    ],
    "de-DE": [
        {"name": "de-DE-KatjaNeural", "gender": "Female", "style": "professional"},
        {"name": "de-DE-ConradNeural", "gender": "Male", "style": "casual"},
    ],
    "ja-JP": [
        {"name": "ja-JP-NanamiNeural", "gender": "Female", "style": "professional"},
        {"name": "ja-JP-KeitaNeural", "gender": "Male", "style": "casual"},
    ],
    "ko-KR": [
        {"name": "ko-KR-SunHiNeural", "gender": "Female", "style": "professional"},
        {"name": "ko-KR-InJoonNeural", "gender": "Male", "style": "casual"},
    ],
    "zh-CN": [
        {"name": "zh-CN-XiaoxiaoNeural", "gender": "Female", "style": "professional"},
        {"name": "zh-CN-YunxiNeural", "gender": "Male", "style": "casual"},
    ],
}
64
+
65
+
66
class EdgeTTSService:
    """
    Text-to-Speech service using Microsoft Edge TTS (free, neural voices)
    """

    def __init__(self):
        """Initialize the Edge TTS service"""
        self._all_voices = None

    # Class-level cache: the remote voice list is fetched once per process.
    _voices_cache = None

    @staticmethod
    def _run_coro_blocking(coro):
        """Run a coroutine to completion from synchronous code.

        Handles all three situations:
        - no event loop yet: create one and run the coroutine on it;
        - an idle loop exists: run the coroutine on it;
        - the current thread's loop is already running: execute the
          coroutine via asyncio.run on a worker thread. (The previous
          implementation submitted the coroutine back to the running loop
          with run_coroutine_threadsafe and then blocked on .result(),
          which deadlocks when called from that loop's own thread.)
        """
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        if loop.is_running():
            import concurrent.futures
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
                return pool.submit(asyncio.run, coro).result()

        return loop.run_until_complete(coro)

    async def get_voices(self, language: Optional[str] = None) -> List[Dict[str, Any]]:
        """
        Get available voices, optionally filtered by language-code prefix.

        The live edge-tts list is memoized at class level; on fetch failure
        the static VOICE_CATALOG is used instead.
        """
        if EdgeTTSService._voices_cache is None:
            try:
                voices = await edge_tts.list_voices()

                # Transform the edge-tts records into our API shape.
                formatted_voices = []
                for v in voices:
                    formatted_voices.append({
                        "name": v["ShortName"],
                        "display_name": v["ShortName"].replace("-", " ").split("Neural")[0].strip(),
                        "language_code": v["Locale"],
                        "gender": v["Gender"],
                        "voice_type": "Neural",
                    })

                EdgeTTSService._voices_cache = formatted_voices
            except Exception as e:
                logger.error(f"Failed to fetch voices from Edge TTS: {e}. Falling back to catalog.")
                # Fallback to the static catalog defined above.
                voices = []
                for lang, lang_voices in VOICE_CATALOG.items():
                    for v in lang_voices:
                        voices.append({
                            "name": v["name"],
                            "display_name": v["name"].replace("-", " ").replace("Neural", "").strip(),
                            "language_code": lang,
                            "gender": v["gender"],
                            "voice_type": "Neural",
                        })
                EdgeTTSService._voices_cache = voices

        voices = EdgeTTSService._voices_cache

        # Filter by language if specified
        if language:
            voices = [v for v in voices if v["language_code"].startswith(language)]

        return voices

    def get_voices_sync(self, language: Optional[str] = None) -> List[Dict[str, Any]]:
        """Synchronous wrapper for get_voices (safe from a running loop's thread)."""
        return self._run_coro_blocking(self.get_voices(language))

    def build_ssml(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "medium",
        pitch: str = "medium",
        emphasis: Optional[str] = None,
        breaks: bool = True
    ) -> str:
        """
        Build SSML markup for advanced prosody control.

        Args:
            text: Plain text to convert
            voice: Voice name
            rate: Speed - 'x-slow', 'slow', 'medium', 'fast', 'x-fast' or percentage
            pitch: Pitch - 'x-low', 'low', 'medium', 'high', 'x-high' or Hz offset
            emphasis: Optional emphasis level - 'reduced', 'moderate', 'strong'
            breaks: Auto-insert breaks at punctuation

        Returns:
            SSML-formatted string
        """
        # Named values and raw percentage/Hz offsets are both passed through
        # verbatim (the former rate_value/pitch_value branches were no-ops:
        # each returned the input either way).
        ssml_parts = ['<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis" xml:lang="en-US">']
        ssml_parts.append(f'<voice name="{voice}">')
        ssml_parts.append(f'<prosody rate="{rate}" pitch="{pitch}">')

        if emphasis:
            ssml_parts.append(f'<emphasis level="{emphasis}">')

        # Auto-insert breaks for natural speech
        if breaks:
            import re
            # Short pause after clause punctuation, longer after sentences.
            processed_text = re.sub(r'([,;:])\s*', r'\1<break time="200ms"/>', text)
            processed_text = re.sub(r'([.!?])\s+', r'\1<break time="500ms"/>', processed_text)
            ssml_parts.append(processed_text)
        else:
            ssml_parts.append(text)

        if emphasis:
            ssml_parts.append('</emphasis>')

        ssml_parts.append('</prosody>')
        ssml_parts.append('</voice>')
        ssml_parts.append('</speak>')

        return ''.join(ssml_parts)

    async def synthesize_ssml(
        self,
        ssml_text: str,
        voice: str = "en-US-AriaNeural",
    ) -> bytes:
        """
        Synthesize speech from SSML markup.

        Args:
            ssml_text: SSML-formatted text
            voice: Voice name (for edge-tts communication)

        Returns:
            Audio bytes (MP3)
        """
        logger.info(f"Synthesizing SSML with voice: {voice}")

        # Edge TTS handles SSML natively
        communicate = edge_tts.Communicate(ssml_text, voice)

        audio_buffer = io.BytesIO()
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                audio_buffer.write(chunk["data"])

        audio_buffer.seek(0)
        return audio_buffer.read()

    async def synthesize_stream(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "+0%",
        pitch: str = "+0Hz",
    ):
        """
        Stream speech synthesis chunks.

        Yields MP3 byte chunks sentence-by-sentence to reduce TTFB (Time To
        First Byte) instead of buffering the full text.
        """
        import re

        # Split text into sentences ending with . ! ? (or end of string),
        # keeping the punctuation, so each sentence synthesizes eagerly.
        sentences = re.findall(r'[^.!?]+(?:[.!?]+|$)', text)
        if not sentences:
            sentences = [text]

        logger.info(f"Streaming {len(sentences)} sentences for low latency...")

        for sentence in sentences:
            if not sentence.strip():
                continue

            communicate = edge_tts.Communicate(sentence, voice, rate=rate, pitch=pitch)

            async for chunk in communicate.stream():
                if chunk["type"] == "audio":
                    yield chunk["data"]

    async def synthesize(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "+0%",
        pitch: str = "+0Hz",
    ) -> bytes:
        """
        Synthesize speech from text

        Args:
            text: Text to synthesize
            voice: Voice name (e.g., 'en-US-AriaNeural')
            rate: Speaking rate adjustment (e.g., '+20%', '-10%')
            pitch: Pitch adjustment (e.g., '+5Hz', '-10Hz')

        Returns:
            Audio content as bytes (MP3 format)
        """
        # Reuse the streaming path to avoid duplicating synthesis logic.
        audio_buffer = io.BytesIO()
        async for chunk in self.synthesize_stream(text, voice, rate, pitch):
            audio_buffer.write(chunk)

        audio_buffer.seek(0)
        return audio_buffer.read()

    def synthesize_sync(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        rate: str = "+0%",
        pitch: str = "+0Hz",
    ) -> bytes:
        """Synchronous wrapper for synthesize (safe from a running loop's thread)."""
        return self._run_coro_blocking(self.synthesize(text, voice, rate, pitch))

    async def synthesize_to_response(
        self,
        text: str,
        voice: str = "en-US-AriaNeural",
        speaking_rate: float = 1.0,
        pitch: float = 0.0,
    ) -> Dict[str, Any]:
        """
        Synthesize speech and return API-compatible response

        Args:
            text: Text to synthesize
            voice: Voice name
            speaking_rate: Rate multiplier (1.0 = normal, 1.5 = 50% faster)
            pitch: Pitch adjustment in semitones (-20 to +20)

        Returns:
            Dictionary with base64 audio content and metadata
        """
        import base64
        import time

        start_time = time.time()

        # Convert the multiplier/semitone API to Edge TTS percent/Hz strings.
        rate_percent = int((speaking_rate - 1.0) * 100)
        rate_str = f"+{rate_percent}%" if rate_percent >= 0 else f"{rate_percent}%"
        pitch_str = f"+{int(pitch)}Hz" if pitch >= 0 else f"{int(pitch)}Hz"

        # Synthesize
        audio_bytes = await self.synthesize(text, voice, rate_str, pitch_str)

        processing_time = time.time() - start_time

        # Rough duration estimate (~150 chars per second at normal speed).
        estimated_duration = len(text) / 150 / speaking_rate

        return {
            "audio_content": base64.b64encode(audio_bytes).decode("utf-8"),
            "encoding": "MP3",
            "audio_size": len(audio_bytes),
            "duration_estimate": estimated_duration,
            "voice_used": voice,
            "processing_time": processing_time,
            "cached": False,
        }
346
+
347
+
348
# Singleton instance (created lazily)
_edge_tts_service: Optional[EdgeTTSService] = None


def get_edge_tts_service() -> EdgeTTSService:
    """Return the module-wide EdgeTTSService, constructing it on first call."""
    global _edge_tts_service
    if _edge_tts_service is None:
        _edge_tts_service = EdgeTTSService()
    return _edge_tts_service
backend/app/services/emotion_service.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Emotion Analysis Service
3
+ Detects emotion from audio using Wav2Vec2 and text using NLP
4
+ """
5
+
6
+ import logging
7
+ import os
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from typing import Dict, List, Any, Optional
12
+
13
+ from app.core.config import get_settings
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
class EmotionService:
    """
    Speech Emotion Recognition (SER) service.

    Wraps 'ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition', a
    Wav2Vec2 sequence classifier over 8 emotion classes. The model and
    processor are lazily loaded on first use to keep startup RAM low.
    """

    def __init__(self):
        # HuggingFace model id for the Wav2Vec2 emotion classifier.
        self.model_name = "ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition"
        # Model/processor stay None until _load_model() is first called.
        self._model = None
        self._processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Class labels in the order the model emits its logits.
        self.emotions = [
            "angry", "calm", "disgust", "fearful",
            "happy", "neutral", "sad", "surprised"
        ]

    def _load_model(self):
        """Load the processor and model on first call (lazy, saves RAM)."""
        if self._model is None:
            try:
                from transformers import Wav2Vec2Processor, Wav2Vec2ForSequenceClassification

                logger.info(f"🎭 Loading Emotion Model ({self.device})...")
                self._processor = Wav2Vec2Processor.from_pretrained(self.model_name)
                self._model = Wav2Vec2ForSequenceClassification.from_pretrained(self.model_name)
                self._model.to(self.device)
                logger.info("✅ Emotion Model loaded")
            except Exception as e:
                logger.error(f"Failed to load emotion model: {e}")
                raise

    def analyze_audio(self, audio_path: str) -> Dict[str, Any]:
        """
        Analyze the emotion of an audio file.

        Args:
            audio_path: Path to the audio file on disk

        Returns:
            Dict with keys:
              - dominant_emotion: label with the highest probability
              - confidence: probability of that label
              - distribution: full label -> probability mapping

        Raises:
            Exception: propagated from model loading, audio decoding,
                or inference failures.
        """
        import librosa  # lazy import: heavy dependency, only needed here

        self._load_model()

        try:
            # Wav2Vec2 expects 16 kHz input. Only the first 60 s are
            # decoded to bound memory; longer files would need chunking.
            y, sr = librosa.load(audio_path, sr=16000, duration=60)

            inputs = self._processor(y, sampling_rate=16000, return_tensors="pt", padding=True)
            # Move every tensor to the model's device before inference.
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                logits = self._model(**inputs).logits

            # Softmax over the class dimension -> per-emotion probabilities.
            probs = F.softmax(logits, dim=-1)[0].cpu().numpy()

            scores = {
                self.emotions[i]: float(probs[i])
                for i in range(len(self.emotions))
            }

            dominant = max(scores, key=scores.get)

            return {
                "dominant_emotion": dominant,
                "confidence": scores[dominant],
                "distribution": scores
            }

        except Exception as e:
            logger.error(f"Audio emotion analysis failed: {e}")
            # Bare re-raise preserves the original traceback (was `raise e`,
            # which resets the exception's context chain).
            raise

    def analyze_audio_segment(self, audio_data: np.ndarray, sr: int = 16000) -> Dict[str, Any]:
        """
        Analyze a raw in-memory audio segment.

        Args:
            audio_data: 1-D waveform samples (assumed mono — TODO confirm
                against streaming callers)
            sr: Sample rate of the segment (model was trained at 16 kHz)

        Returns:
            Dict with 'emotion' (label) and 'score' (probability). Falls
            back to {'emotion': 'neutral', 'score': 0.0} on any failure.
        """
        self._load_model()

        try:
            inputs = self._processor(audio_data, sampling_rate=sr, return_tensors="pt", padding=True)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                logits = self._model(**inputs).logits

            probs = F.softmax(logits, dim=-1)[0].cpu().numpy()
            scores = {self.emotions[i]: float(probs[i]) for i in range(len(self.emotions))}
            dominant = max(scores, key=scores.get)

            return {
                "emotion": dominant,
                "score": scores[dominant]
            }
        except Exception as e:
            # Best-effort path: streaming callers prefer a neutral fallback
            # over a crash mid-stream.
            logger.error(f"Segment analysis failed: {e}")
            return {"emotion": "neutral", "score": 0.0}
123
+
124
+
125
# Singleton
# Annotated holder, matching the edge_tts_service singleton pattern.
_emotion_service: Optional["EmotionService"] = None


def get_emotion_service() -> EmotionService:
    """Get or create the EmotionService singleton."""
    global _emotion_service
    if _emotion_service is None:
        _emotion_service = EmotionService()
    return _emotion_service