Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
import asyncio
import io
import os
import threading
import uuid
from datetime import datetime, timezone
from functools import lru_cache
from typing import Optional

import pymongo
import requests
import soundfile as sf
from bson.binary import Binary
from bson.objectid import ObjectId
from dotenv import load_dotenv
from fastapi import Body, FastAPI, Form, HTTPException, Request, Response
from pydantic import BaseModel
from pymongo.errors import PyMongoError

from model import ENGLISH_REPO_ID, get_pretrained_model
# Load variables from a .env file (if present) into the process environment
# before any of the settings below are read.
load_dotenv()

# MongoDB settings, each read from the environment with surrounding whitespace
# stripped. MONGO_URI has no default; _get_mongo_client rejects it when empty.
MONGO_URI, MONGO_DB_NAME, MONGO_COLLECTION, MONGO_CAPTIONS_COLLECTION = (
    os.getenv(env_name, fallback).strip()
    for env_name, fallback in (
        ("MONGO_URI", ""),
        ("MONGO_DB_NAME", "image_to_speech"),
        ("MONGO_COLLECTION", "audio"),
        ("MONGO_CAPTIONS_COLLECTION", "captions"),
    )
)

# Client-facing messages used when Firebase token verification fails.
ERRORS = dict(
    TOKEN_MISSING="firebase_id_token is missing",
    TOKEN_INVALID="Invalid Firebase token",
)
def log(msg: str) -> None:
    """Print *msg* to stdout, prefixed with the local time to the microsecond."""
    print(f"{datetime.now():%Y-%m-%d %H:%M:%S.%f}: {msg}")
@lru_cache(maxsize=1)
def _get_mongo_client():
    """Return the process-wide MongoClient, creating it on first use.

    A MongoClient owns a connection pool and background monitoring threads and
    is meant to be created once per process and reused. Building a new one on
    every call (the collection helpers call this per query) leaks pools and
    repeats server selection on each request, so the client is cached.

    Raises:
        ValueError: if MONGO_URI is empty. Exceptions are not cached by
            ``lru_cache``, so the check repeats on every call until it passes.
    """
    if not MONGO_URI:
        raise ValueError("MONGO_URI is missing in .env")
    return pymongo.MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
def _get_mongo_collection():
    """Return the audio collection (MONGO_COLLECTION) inside MONGO_DB_NAME."""
    database = _get_mongo_client()[MONGO_DB_NAME]
    return database[MONGO_COLLECTION]
def _get_captions_collection():
    """Return the captions collection (MONGO_CAPTIONS_COLLECTION) inside MONGO_DB_NAME."""
    database = _get_mongo_client()[MONGO_DB_NAME]
    return database[MONGO_CAPTIONS_COLLECTION]
def _as_opus_bytes(samples, sample_rate: int) -> bytes:
    """Encode *samples* into an in-memory Ogg container with the Opus codec.

    Returns the encoded file contents as bytes.
    """
    with io.BytesIO() as sink:
        sf.write(sink, samples, samplerate=sample_rate, format="OGG", subtype="OPUS")
        encoded = sink.getvalue()
    return encoded
def _save_audio_to_db(
    samples,
    sample_rate: int,
    caption_id: Optional[str] = None,
    caption: Optional[str] = None,
) -> dict:
    """Encode *samples* as Ogg/Opus and insert them into the audio collection.

    Args:
        samples: Audio samples; ``len(samples)`` is used as the frame count
            when computing the duration.
        sample_rate: Samples per second, used for encoding and for the duration.
        caption_id: Stored on the document only when truthy.
        caption: Caption text, stored on the document only when truthy.

    Returns:
        A summary dict with the inserted ``_id`` as ``audio_file_id``, the new
        ``audio_id``, the relative ``audio_url``, ``sample_rate``,
        ``duration_seconds``, ``caption_id`` and ``caption``.
    """
    audio_id = str(uuid.uuid4())
    duration = len(samples) / sample_rate
    opus_bytes = _as_opus_bytes(samples, sample_rate)
    audio_url = f"/audio/{audio_id}.opus"

    doc = {
        "audio_id": audio_id,
        "audio_url": audio_url,
        "audio_file": Binary(opus_bytes),
        "sample_rate": int(sample_rate),
        "duration_seconds": float(duration),
        "audio_format": "opus",
        # Timezone-aware UTC: datetime.utcnow() is deprecated since Python 3.12.
        # BSON stores an aware datetime and a naive-UTC one as the same instant.
        "created_at": datetime.now(timezone.utc),
    }
    if caption_id:
        doc["caption_id"] = caption_id
    if caption:
        doc["caption"] = caption

    inserted = _get_mongo_collection().insert_one(doc)
    return {
        "audio_file_id": str(inserted.inserted_id),
        "audio_id": audio_id,
        "audio_url": audio_url,
        "sample_rate": int(sample_rate),
        "duration_seconds": float(duration),
        "caption_id": caption_id,
        "caption": caption,
    }
def _generate_audio_from_text(
    text: str,
    sid: int,
    speed: float,
    caption_id: Optional[str] = None,
) -> dict:
    """Synthesize *text* with the English TTS model and persist the result.

    Args:
        text: Text to speak; also stored as the document's caption.
        sid: Speaker id passed to the model's ``generate``.
        speed: Speed value passed to ``get_pretrained_model``.
        caption_id: Optional caption identifier stored with the audio.

    Returns:
        The summary dict produced by ``_save_audio_to_db``.

    Raises:
        ValueError: if the model produced no samples.
    """
    generated = get_pretrained_model(ENGLISH_REPO_ID, speed).generate(text, sid=sid)
    if not len(generated.samples):
        raise ValueError("No audio was generated.")
    return _save_audio_to_db(
        generated.samples,
        generated.sample_rate,
        caption_id=caption_id,
        caption=text,
    )
# JSON body model accepted by get_audio_by_id as an alternative to form fields.
class AudioByIdRequest(BaseModel):
    # Identifier looked up as audio_id, then _id, then (if a valid ObjectId)
    # as a caption_id / caption document _id by get_audio_by_id.
    audio_id: str
    # Speaker id forwarded to the TTS model when audio has to be generated.
    sid: Optional[int] = 0
    # Speed forwarded to get_pretrained_model when audio has to be generated.
    speed: Optional[float] = 1.0
    # Firebase ID token; get_audio_by_id prefers the form field if one is sent.
    firebase_id_token: Optional[str] = None
# The FastAPI application; re-exported below as `app` for the ASGI server.
api: FastAPI = FastAPI(title="Text-to-Speech API")
| def _api_response(succes: bool, messase: str, data): | |
| return {"succes": succes, "messase": messase, "data": data} | |
# NOTE(review): no route decorator is visible on this handler here; confirm it
# is registered on `api`.
def health_check():
    """Return a constant status payload signalling that the service is up."""
    return dict(status="ok")
def _find_audio_doc(identifier: str):
    """Look up an audio document by ``audio_id``, falling back to ``_id``.

    Args:
        identifier: Either an ``audio_id`` value or a hex ObjectId string.

    Returns:
        The matching document, or None. The ``_id`` lookup only happens when
        no ``audio_id`` matches and *identifier* is a valid ObjectId.
    """
    # Resolve the collection once: each _get_mongo_collection() call goes
    # through _get_mongo_client(), so the fallback shouldn't fetch it again.
    collection = _get_mongo_collection()
    doc = collection.find_one({"audio_id": identifier})
    if doc is None and ObjectId.is_valid(identifier):
        doc = collection.find_one({"_id": ObjectId(identifier)})
    return doc
| def _get_firebase_api_key() -> str: | |
| # Read at runtime so env updates after restarts are reflected. | |
| return ( | |
| os.getenv("FIREBASE_API_KEY", "").strip() | |
| or os.getenv("FIREBASE_WEB_API_KEY", "").strip() | |
| or os.getenv("FIREBASE_APIKEY", "").strip() | |
| ) | |
async def verify_firebase_token(firebase_id_token: str) -> dict:
    """Verify a Firebase ID token via the Identity Toolkit REST API.

    Posts the token to ``accounts:lookup`` and returns the first user record.

    Args:
        firebase_id_token: The ID token supplied by the client.

    Returns:
        The first entry of the response's ``users`` list.

    Raises:
        HTTPException: 401 if the token is missing, rejected, or resolves to no
            user; 500 if no Firebase API key is configured; 502 if a 200
            response does not carry a JSON object; 503 if the HTTP request
            itself fails.
    """
    if not firebase_id_token:
        raise HTTPException(status_code=401, detail=ERRORS["TOKEN_MISSING"])

    firebase_api_key = _get_firebase_api_key()
    if not firebase_api_key:
        raise HTTPException(
            status_code=500,
            detail="FIREBASE_API_KEY is missing in environment configuration",
        )

    url = f"https://identitytoolkit.googleapis.com/v1/accounts:lookup?key={firebase_api_key}"
    payload = {"idToken": firebase_id_token}
    try:
        # requests is blocking; run it in a worker thread to keep the loop free.
        resp = await asyncio.to_thread(requests.post, url, json=payload, timeout=10)
    except requests.RequestException as e:
        # requests' error text can include the request URL, i.e. the API key in
        # the query string. Redact it before it reaches logs or the client.
        reason = str(e).replace(firebase_api_key, "<redacted>")
        log(f"Firebase verification request failed: {reason}")
        raise HTTPException(
            status_code=503,
            detail=f"Firebase verification service unavailable: {reason}",
        ) from None

    try:
        body = resp.json()
    except ValueError:  # empty or non-JSON body
        body = None

    if resp.status_code != 200:
        # Firebase errors look like {"error": {"message": ...}}; tolerate any shape.
        error = body.get("error") if isinstance(body, dict) else None
        message = error.get("message") if isinstance(error, dict) else None
        raise HTTPException(
            status_code=401,
            detail=f"Firebase token verification failed: {message or ERRORS['TOKEN_INVALID']}",
        )

    if not isinstance(body, dict):
        raise HTTPException(status_code=502, detail="Invalid response from Firebase verification service")

    users = body.get("users") or []
    if not users:
        raise HTTPException(status_code=401, detail="Firebase token verification failed: no user found")
    return users[0]
# Serializes TTS synthesis, which get_audio_by_id now runs in worker threads.
# One generation at a time matches the behaviour when synthesis ran on the event
# loop, and avoids assuming the model from get_pretrained_model is thread-safe.
_TTS_GENERATION_LOCK = threading.Lock()


def _resolve_audio_doc(identifier: str, sid: int, speed: float):
    """Find the stored audio document for *identifier*, generating it if needed.

    Lookup order:
      1. ``audio_id`` field, then ``_id`` (via ``_find_audio_doc``).
      2. If *identifier* is a valid ObjectId: an audio document whose
         ``caption_id`` equals it.
      3. The caption document with that ``_id``: when its ``caption`` text is
         non-blank, synthesize audio from it and return the newly stored doc.

    Generation failures are logged and reported as "not found" (best effort).
    Database errors from the lookups propagate to the caller. This function
    blocks (MongoDB I/O and TTS synthesis); call it via ``asyncio.to_thread``
    from async code.
    """
    doc = _find_audio_doc(identifier)
    if doc or not ObjectId.is_valid(identifier):
        return doc

    doc = _get_mongo_collection().find_one({"caption_id": identifier})
    if doc:
        return doc

    caption_doc = _get_captions_collection().find_one({"_id": ObjectId(identifier)})
    if not caption_doc:
        return None
    caption_text = str(caption_doc.get("caption", "")).strip()
    if not caption_text:
        return None

    with _TTS_GENERATION_LOCK:
        # A concurrent request may have generated this caption while we waited.
        existing = _get_mongo_collection().find_one({"caption_id": identifier})
        if existing:
            return existing
        try:
            saved = _generate_audio_from_text(
                caption_text,
                sid=sid,
                speed=speed,
                caption_id=identifier,
            )
            return _find_audio_doc(saved["audio_id"])
        except Exception as e:
            log(f"Error generating audio from caption {identifier}: {e}")
            return None


# NOTE(review): no route decorator is visible on this handler here; the log
# messages suggest it serves /audio/by-id — confirm it is registered on `api`.
async def get_audio_by_id(
    request: Request,
    payload: Optional[AudioByIdRequest] = Body(default=None),
    audio_id: Optional[str] = Form(default=None),
    sid: Optional[int] = Form(default=0),
    speed: Optional[float] = Form(default=1.0),
    firebase_id_token: Optional[str] = Form(default=None),
):
    """Return metadata and a streaming URL for an audio document.

    Inputs may come from form fields or a JSON body. ``audio_id`` and
    ``firebase_id_token`` prefer the form value. ``sid`` and ``speed`` prefer
    the body value when it is not None. The Firebase token is verified before
    anything else. If no audio exists yet for a caption id, it is generated
    (see ``_resolve_audio_doc``).

    Returns:
        An ``_api_response`` envelope. On success, ``data`` holds ``audio_id``,
        an absolute ``audio_url``, ``sample_rate``, ``duration_seconds`` and
        ``caption``.

    Raises:
        HTTPException: from token verification; 503 on MongoDB errors; 500 on
            any other unexpected error.
    """
    try:
        resolved_audio_id = audio_id or (payload.audio_id if payload else None)
        resolved_sid = payload.sid if payload and payload.sid is not None else sid
        resolved_speed = payload.speed if payload and payload.speed is not None else speed
        resolved_firebase_token = firebase_id_token or (payload.firebase_id_token if payload else None)

        await verify_firebase_token(resolved_firebase_token)

        if not resolved_audio_id:
            return _api_response(False, "audio_id is required", None)

        # pymongo and TTS synthesis are blocking; keep them off the event loop.
        doc = await asyncio.to_thread(
            _resolve_audio_doc, resolved_audio_id, resolved_sid, resolved_speed
        )
        if not doc:
            return _api_response(False, "Audio not found", None)

        # Truthiness check: no need to copy the blob, and a stored None counts
        # as missing rather than failing in bytes(None).
        if not doc.get("audio_file"):
            return _api_response(False, "Document found but audio_file is missing", None)

        resolved_id = str(doc.get("audio_id") or doc.get("_id"))
        audio_url = str(request.base_url) + f"audio/{resolved_id}.opus"
        return _api_response(
            True,
            "Audio fetched successfully",
            {
                "audio_id": resolved_id,
                "audio_url": audio_url,
                "sample_rate": int(doc.get("sample_rate", 0)),
                "duration_seconds": float(doc.get("duration_seconds", 0.0)),
                "caption": doc.get("caption"),
            },
        )
    except HTTPException:
        raise
    except PyMongoError as e:
        log(f"MongoDB error in /audio/by-id: {e}")
        raise HTTPException(status_code=503, detail="Database service unavailable")
    except Exception as e:
        log(f"Unhandled error in /audio/by-id: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
# NOTE(review): no route decorator is visible on this handler here; the log
# messages suggest /audio/{audio_id}.opus — confirm it is registered on `api`.
def stream_audio(audio_id: str):
    """Serve a stored recording's bytes as ``audio/ogg``.

    Returns 404 when no document matches or its ``audio_file`` is empty, 503
    on MongoDB errors, and 500 on any other error. Successful responses are
    marked publicly cacheable for one year.
    """
    try:
        found = _find_audio_doc(audio_id)
        payload = bytes(found.get("audio_file", b"")) if found else b""
        if not payload:
            return Response(status_code=404)

        file_stem = str(found.get("audio_id") or found.get("_id"))
        headers = {
            "Content-Disposition": f'inline; filename="{file_stem}.opus"',
            "Cache-Control": "public, max-age=31536000",
        }
        return Response(content=payload, media_type="audio/ogg", headers=headers)
    except PyMongoError as exc:
        log(f"MongoDB error in /audio/{{audio_id}}.opus: {exc}")
        return Response(status_code=503)
    except Exception as exc:
        log(f"Unhandled error in /audio/{{audio_id}}.opus: {exc}")
        return Response(status_code=500)
# ASGI entry point name expected by the server (module attribute `app`).
app = api

log(f"FIREBASE API key present at startup: {bool(_get_firebase_api_key())}")

if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "app:app" import string. When this file runs
    # as a script it is module __main__; an import string would make uvicorn
    # import it again as "app", re-running all top-level code (load_dotenv, the
    # startup log, FastAPI construction), and would only work if the file is
    # named app.py. The import string is only required for reload/workers.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)