TextToAudio / app.py
vidhi0405's picture
auth
e88ed83
#!/usr/bin/env python3
import asyncio
import io
import os
import uuid
from datetime import datetime
from functools import lru_cache
from typing import Optional
import pymongo
import requests
import soundfile as sf
from bson.binary import Binary
from bson.objectid import ObjectId
from dotenv import load_dotenv
from fastapi import Body, FastAPI, Form, HTTPException, Request, Response
from pymongo.errors import PyMongoError
from pydantic import BaseModel
from model import ENGLISH_REPO_ID, get_pretrained_model
# Load variables from a .env file (python-dotenv) before any setting is read.
load_dotenv()

# MongoDB settings, whitespace-stripped.  MONGO_URI has no usable default; it
# is checked when the client is first created, not here at import time.
MONGO_URI = os.environ.get("MONGO_URI", "").strip()
MONGO_DB_NAME = os.environ.get("MONGO_DB_NAME", "image_to_speech").strip()
MONGO_COLLECTION = os.environ.get("MONGO_COLLECTION", "audio").strip()
MONGO_CAPTIONS_COLLECTION = os.environ.get(
    "MONGO_CAPTIONS_COLLECTION", "captions"
).strip()

# Messages used in HTTP error details for Firebase token failures.
ERRORS = dict(
    TOKEN_MISSING="firebase_id_token is missing",
    TOKEN_INVALID="Invalid Firebase token",
)
def log(msg: str) -> None:
    """Print *msg* to stdout, prefixed with the current local timestamp.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS.ffffff`` and is separated
    from the message by ``": "``.
    """
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
    print(timestamp, msg, sep=": ")
@lru_cache(maxsize=1)
def _get_mongo_client():
    """Return the shared ``MongoClient``, creating it on the first call.

    ``lru_cache`` memoises the returned client, so later calls reuse it.
    A raised exception is not cached; the check runs again next time.

    Raises:
        ValueError: If ``MONGO_URI`` is empty.
    """
    if MONGO_URI:
        # 5 s server-selection timeout instead of the driver default.
        return pymongo.MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
    raise ValueError("MONGO_URI is missing in .env")
def _get_mongo_collection():
    """Return the collection named by ``MONGO_COLLECTION`` (audio documents)."""
    database = _get_mongo_client()[MONGO_DB_NAME]
    return database[MONGO_COLLECTION]
def _get_captions_collection():
    """Return the collection named by ``MONGO_CAPTIONS_COLLECTION``."""
    database = _get_mongo_client()[MONGO_DB_NAME]
    return database[MONGO_CAPTIONS_COLLECTION]
def _as_opus_bytes(samples, sample_rate: int) -> bytes:
    """Encode *samples* as Opus in an Ogg container and return the bytes.

    The audio is written to an in-memory buffer with ``soundfile``; nothing
    touches the filesystem.
    """
    with io.BytesIO() as sink:
        sf.write(sink, samples, samplerate=sample_rate, format="OGG", subtype="OPUS")
        return sink.getvalue()
def _save_audio_to_db(
samples,
sample_rate: int,
caption_id: Optional[str] = None,
caption: Optional[str] = None,
) -> dict:
audio_id = str(uuid.uuid4())
duration = len(samples) / sample_rate
opus_bytes = _as_opus_bytes(samples, sample_rate)
audio_url = f"/audio/{audio_id}.opus"
doc = {
"audio_id": audio_id,
"audio_url": audio_url,
"audio_file": Binary(opus_bytes),
"sample_rate": int(sample_rate),
"duration_seconds": float(duration),
"audio_format": "opus",
"created_at": datetime.utcnow(),
}
if caption_id:
doc["caption_id"] = caption_id
if caption:
doc["caption"] = caption
inserted = _get_mongo_collection().insert_one(doc)
return {
"audio_file_id": str(inserted.inserted_id),
"audio_id": audio_id,
"audio_url": audio_url,
"sample_rate": int(sample_rate),
"duration_seconds": float(duration),
"caption_id": caption_id,
"caption": caption,
}
def _generate_audio_from_text(
    text: str,
    sid: int,
    speed: float,
    caption_id: Optional[str] = None,
) -> dict:
    """Synthesise speech for *text* and store it in the audio collection.

    Args:
        text: Text to speak; also saved as the document's caption.
        sid: Passed to the model's ``generate`` as ``sid``
            (presumably a speaker id; confirm against ``model``).
        speed: Passed to ``get_pretrained_model`` with the English repo id.
        caption_id: Optional caption id recorded with the audio.

    Returns:
        The dict returned by ``_save_audio_to_db``.

    Raises:
        ValueError: If the model returns zero samples.
    """
    engine = get_pretrained_model(ENGLISH_REPO_ID, speed)
    generated = engine.generate(text, sid=sid)
    samples = generated.samples
    if len(samples) == 0:
        raise ValueError("No audio was generated.")
    return _save_audio_to_db(
        samples,
        generated.sample_rate,
        caption_id=caption_id,
        caption=text,
    )
class AudioByIdRequest(BaseModel):
    """JSON body accepted by ``POST /audio/by-id``.

    The fields mirror the form parameters of the same endpoint.
    """

    # Identifier to resolve: an audio_id, a Mongo ObjectId string, or a caption id.
    audio_id: str
    # Value forwarded as ``sid`` when audio has to be generated.
    sid: Optional[int] = 0
    # Value forwarded as ``speed`` when audio has to be generated.
    speed: Optional[float] = 1.0
    # Firebase ID token checked before any lookup is performed.
    firebase_id_token: Optional[str] = None
# FastAPI application object; the route decorators in this module attach to it.
api = FastAPI(title="Text-to-Speech API")
def _api_response(succes: bool, messase: str, data):
return {"succes": succes, "messase": messase, "data": data}
@api.get("/health")
def health_check():
    """Liveness probe: always returns ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
def _find_audio_doc(identifier: str):
    """Look up an audio document by ``audio_id``, then by Mongo ``_id``.

    The ``_id`` lookup only runs when *identifier* is a valid ObjectId
    string. Returns the matching document, or ``None`` if neither matches.
    """
    collection = _get_mongo_collection()
    by_audio_id = collection.find_one({"audio_id": identifier})
    if by_audio_id:
        return by_audio_id
    if not ObjectId.is_valid(identifier):
        return None
    return collection.find_one({"_id": ObjectId(identifier)})
def _get_firebase_api_key() -> str:
# Read at runtime so env updates after restarts are reflected.
return (
os.getenv("FIREBASE_API_KEY", "").strip()
or os.getenv("FIREBASE_WEB_API_KEY", "").strip()
or os.getenv("FIREBASE_APIKEY", "").strip()
)
async def verify_firebase_token(firebase_id_token: Optional[str]) -> dict:
    """Verify a Firebase ID token via the Identity Toolkit ``accounts:lookup`` call.

    Args:
        firebase_id_token: ID token supplied by the client; may be empty
            or ``None``.

    Returns:
        The first entry of the ``users`` list in Firebase's response.

    Raises:
        HTTPException: 401 when the token is missing, rejected by Firebase,
            or maps to no user; 500 when no API key is configured; 502 when
            a 200 response is not the expected JSON object; 503 when the
            HTTP request itself fails.
    """
    if not firebase_id_token:
        raise HTTPException(status_code=401, detail=ERRORS["TOKEN_MISSING"])

    firebase_api_key = _get_firebase_api_key()
    if not firebase_api_key:
        raise HTTPException(
            status_code=500,
            detail="FIREBASE_API_KEY is missing in environment configuration",
        )

    url = "https://identitytoolkit.googleapis.com/v1/accounts:lookup"
    try:
        # requests is blocking, so run it in a worker thread.
        resp = await asyncio.to_thread(
            requests.post,
            url,
            params={"key": firebase_api_key},
            json={"idToken": firebase_id_token},
            timeout=10,
        )
    except requests.RequestException as e:
        # requests error messages commonly contain the full request URL,
        # including the ``key`` query parameter. Never echo them to the
        # client, and log only the exception class.
        log(f"Firebase verification request failed: {type(e).__name__}")
        raise HTTPException(
            status_code=503,
            detail="Firebase verification service unavailable",
        ) from e

    if resp.status_code != 200:
        try:
            detail = resp.json().get("error", {}).get("message", ERRORS["TOKEN_INVALID"])
        except (ValueError, AttributeError):
            # Body was not JSON, or not a JSON object of the expected shape.
            detail = ERRORS["TOKEN_INVALID"]
        raise HTTPException(
            status_code=401,
            detail=f"Firebase token verification failed: {detail}",
        )

    try:
        users = resp.json().get("users", [])
    except (ValueError, AttributeError):
        raise HTTPException(
            status_code=502,
            detail="Invalid response from Firebase verification service",
        )
    if not users:
        raise HTTPException(
            status_code=401,
            detail="Firebase token verification failed: no user found",
        )
    return users[0]
def _resolve_audio_doc(identifier: str, sid: Optional[int], speed: Optional[float]):
    """Find the audio document for *identifier*, generating it from a caption if needed.

    Lookup order:
      1. ``_find_audio_doc`` (``audio_id`` field, then Mongo ``_id``).
      2. An audio document whose ``caption_id`` equals *identifier*.
      3. The captions collection by ``_id``. When a non-empty caption is
         found, audio is generated, stored and read back.

    Steps 2 and 3 run only when *identifier* is a valid ObjectId string, so
    ``ObjectId(identifier)`` cannot raise here.

    Returns:
        The audio document, or ``None`` when nothing could be found or
        generated. Generation failures are logged and reported as ``None``.
    """
    doc = _find_audio_doc(identifier)
    if doc or not ObjectId.is_valid(identifier):
        return doc

    doc = _get_mongo_collection().find_one({"caption_id": identifier})
    if doc:
        return doc

    caption_doc = _get_captions_collection().find_one({"_id": ObjectId(identifier)})
    if not caption_doc:
        return None

    # ``or ""`` keeps a stored null caption from becoming the text "None".
    caption_text = str(caption_doc.get("caption") or "").strip()
    if not caption_text:
        return None

    try:
        saved = _generate_audio_from_text(
            caption_text,
            sid=sid,
            speed=speed,
            caption_id=identifier,
        )
        return _find_audio_doc(saved["audio_id"])
    except Exception as e:
        # Deliberately best-effort: the caller reports "Audio not found".
        log(f"Error generating audio from caption {identifier}: {e}")
        return None


@api.post("/audio/by-id")
async def get_audio_by_id(
    request: Request,
    payload: Optional[AudioByIdRequest] = Body(default=None),
    audio_id: Optional[str] = Form(default=None),
    sid: Optional[int] = Form(default=0),
    speed: Optional[float] = Form(default=1.0),
    firebase_id_token: Optional[str] = Form(default=None),
):
    """Return metadata and a streaming URL for an audio clip.

    Parameters may arrive as form fields or in a JSON ``payload``. For
    ``audio_id`` and ``firebase_id_token`` the form value wins; for ``sid``
    and ``speed`` the payload value wins when a payload is present.

    The Firebase token is verified first. The identifier is then resolved as
    an audio id, a Mongo ``_id`` or a caption id; for a caption id with no
    stored audio, the audio is generated on demand.

    Returns:
        The ``_api_response`` envelope. "Not found" and validation problems
        are reported in the envelope rather than as an HTTP error.

    Raises:
        HTTPException: Whatever ``verify_firebase_token`` raises; 503 on a
            MongoDB error; 500 on any other unexpected error.
    """
    try:
        resolved_audio_id = audio_id or (payload.audio_id if payload else None)
        resolved_sid = payload.sid if payload and payload.sid is not None else sid
        resolved_speed = payload.speed if payload and payload.speed is not None else speed
        resolved_firebase_token = firebase_id_token or (
            payload.firebase_id_token if payload else None
        )

        await verify_firebase_token(resolved_firebase_token)

        if not resolved_audio_id:
            return _api_response(False, "audio_id is required", None)

        # NOTE(review): the Mongo lookups and speech generation below are
        # synchronous calls made directly from this coroutine.
        doc = _resolve_audio_doc(resolved_audio_id, resolved_sid, resolved_speed)
        if not doc:
            return _api_response(False, "Audio not found", None)

        # Truthiness covers a missing key, a null value and empty bytes
        # without copying the audio payload.
        if not doc.get("audio_file"):
            return _api_response(False, "Document found but audio_file is missing", None)

        resolved_id = str(doc.get("audio_id") or doc.get("_id"))
        audio_url = str(request.base_url) + f"audio/{resolved_id}.opus"
        return _api_response(
            True,
            "Audio fetched successfully",
            {
                "audio_id": resolved_id,
                "audio_url": audio_url,
                # ``or`` tolerates fields stored as null.
                "sample_rate": int(doc.get("sample_rate") or 0),
                "duration_seconds": float(doc.get("duration_seconds") or 0.0),
                "caption": doc.get("caption"),
            },
        )
    except HTTPException:
        raise
    except PyMongoError as e:
        log(f"MongoDB error in /audio/by-id: {e}")
        raise HTTPException(status_code=503, detail="Database service unavailable") from e
    except Exception as e:
        log(f"Unhandled error in /audio/by-id: {e}")
        raise HTTPException(status_code=500, detail="Internal server error") from e
@api.get("/audio/{audio_id}.opus")
def stream_audio(audio_id: str):
    """Serve the stored Opus audio for *audio_id* as ``audio/ogg``.

    *audio_id* is resolved with ``_find_audio_doc`` (``audio_id`` field,
    then Mongo ``_id``).

    Returns:
        200 with the audio bytes and a one-year public ``Cache-Control``;
        404 when the document or its audio is missing; 503 on a MongoDB
        error; 500 on any other error. Error responses have no body.
    """
    try:
        doc = _find_audio_doc(audio_id)
        if not doc:
            return Response(status_code=404)

        # ``or b""`` makes a stored null behave like a missing field (404);
        # bytes(None) would raise TypeError and end up as a 500.
        audio_bytes = bytes(doc.get("audio_file") or b"")
        if not audio_bytes:
            return Response(status_code=404)

        resolved_id = str(doc.get("audio_id") or doc.get("_id"))
        return Response(
            content=audio_bytes,
            media_type="audio/ogg",
            headers={
                "Content-Disposition": f'inline; filename="{resolved_id}.opus"',
                "Cache-Control": "public, max-age=31536000",
            },
        )
    except PyMongoError as e:
        log(f"MongoDB error in /audio/{{audio_id}}.opus: {e}")
        return Response(status_code=503)
    except Exception as e:
        log(f"Unhandled error in /audio/{{audio_id}}.opus: {e}")
        return Response(status_code=500)
# Alias the application as ``app``; the uvicorn import string "app:app" below
# presumably resolves to this name when the file is saved as app.py.
app = api

# Startup diagnostic: logs only whether a key is configured, never its value.
log(f"FIREBASE API key present at startup: {bool(_get_firebase_api_key())}")

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )