File size: 9,815 Bytes
855c74b
 
87cdad5
855c74b
 
 
 
 
 
 
 
87cdad5
855c74b
 
 
 
87cdad5
e88ed83
855c74b
 
 
 
 
 
 
 
 
 
87cdad5
 
 
 
855c74b
 
 
 
 
 
 
 
 
 
 
e88ed83
855c74b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9307cdc
855c74b
 
9307cdc
855c74b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9307cdc
855c74b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87cdad5
855c74b
 
 
 
 
 
 
 
 
e88ed83
 
 
 
 
855c74b
 
 
 
 
 
 
 
 
b5adb8f
 
 
 
 
 
 
 
 
87cdad5
 
 
 
b5adb8f
 
9e1c6c2
 
 
 
87cdad5
b5adb8f
87cdad5
 
 
 
 
e88ed83
 
 
 
87cdad5
 
e88ed83
 
 
 
87cdad5
 
 
 
 
 
 
855c74b
87cdad5
855c74b
 
 
 
 
87cdad5
855c74b
e88ed83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855c74b
 
e88ed83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855c74b
 
 
 
e88ed83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
855c74b
 
e88ed83
b5adb8f
855c74b
 
 
e88ed83
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
#!/usr/bin/env python3

import asyncio
import io
import os
import threading
import uuid
from datetime import datetime, timezone
from functools import lru_cache
from typing import Optional

import pymongo
import requests
import soundfile as sf
from bson.binary import Binary
from bson.objectid import ObjectId
from dotenv import load_dotenv
from fastapi import Body, FastAPI, Form, HTTPException, Request, Response
from pydantic import BaseModel
from pymongo.errors import PyMongoError

from model import ENGLISH_REPO_ID, get_pretrained_model

load_dotenv()

def _env_str(name: str, fallback: str) -> str:
    """Return environment variable *name* (or *fallback*) with surrounding whitespace removed."""
    return os.getenv(name, fallback).strip()


# MongoDB settings, read once at import time. Only MONGO_URI lacks a usable default;
# _get_mongo_client() rejects it when empty.
MONGO_URI = _env_str("MONGO_URI", "")
MONGO_DB_NAME = _env_str("MONGO_DB_NAME", "image_to_speech")
MONGO_COLLECTION = _env_str("MONGO_COLLECTION", "audio")
MONGO_CAPTIONS_COLLECTION = _env_str("MONGO_CAPTIONS_COLLECTION", "captions")

# Client-facing messages for Firebase token failures, keyed by error code.
ERRORS = dict(
    TOKEN_MISSING="firebase_id_token is missing",
    TOKEN_INVALID="Invalid Firebase token",
)


def log(msg: str) -> None:
    """Print *msg* to stdout, prefixed by the local time with microsecond precision."""
    print(f"{datetime.now():%Y-%m-%d %H:%M:%S.%f}: {msg}")


@lru_cache(maxsize=1)
def _get_mongo_client():
    """Create the MongoClient for MONGO_URI once and hand back the same instance afterwards.

    Server selection gives up after 5 seconds.

    Raises:
        ValueError: if MONGO_URI is empty.
    """
    uri = MONGO_URI
    if uri:
        return pymongo.MongoClient(uri, serverSelectionTimeoutMS=5000)
    raise ValueError("MONGO_URI is missing in .env")


def _get_mongo_collection():
    """Return the audio collection, MONGO_DB_NAME / MONGO_COLLECTION, on the shared client."""
    database = _get_mongo_client()[MONGO_DB_NAME]
    return database[MONGO_COLLECTION]


def _get_captions_collection():
    """Return the captions collection, MONGO_DB_NAME / MONGO_CAPTIONS_COLLECTION, on the shared client."""
    database = _get_mongo_client()[MONGO_DB_NAME]
    return database[MONGO_CAPTIONS_COLLECTION]


def _as_opus_bytes(samples, sample_rate: int) -> bytes:
    """Encode *samples* at *sample_rate* Hz as an in-memory Ogg/Opus file and return its bytes.

    NOTE(review): libsndfile's Opus encoder is believed to accept only 8, 12, 16,
    24 and 48 kHz input. Confirm the TTS model's sample rate is one of these.
    """
    with io.BytesIO() as sink:
        sf.write(sink, samples, samplerate=sample_rate, format="OGG", subtype="OPUS")
        # Read the buffer before the context manager closes it.
        return sink.getvalue()


def _save_audio_to_db(
    samples,
    sample_rate: int,
    caption_id: Optional[str] = None,
    caption: Optional[str] = None,
) -> dict:
    """Encode *samples* as Ogg/Opus and insert them into the audio collection.

    Args:
        samples: Audio samples accepted by ``soundfile.write``; ``len(samples)``
            is used as the frame count for the duration.
        sample_rate: Sample rate in Hz.
        caption_id: Optional id of the caption this audio was generated from.
        caption: Optional caption text that was synthesized.

    Returns:
        Summary of the stored document: the Mongo ``_id`` (as a string), the
        public ``audio_id`` and relative ``audio_url``, sample rate, duration,
        and the caption fields as passed in.
    """
    audio_id = str(uuid.uuid4())
    duration = len(samples) / sample_rate
    opus_bytes = _as_opus_bytes(samples, sample_rate)
    audio_url = f"/audio/{audio_id}.opus"
    doc = {
        "audio_id": audio_id,
        "audio_url": audio_url,
        "audio_file": Binary(opus_bytes),
        "sample_rate": int(sample_rate),
        "duration_seconds": float(duration),
        "audio_format": "opus",
        # Timezone-aware UTC. datetime.utcnow() is deprecated since Python 3.12,
        # and BSON stores both forms as the same UTC instant.
        "created_at": datetime.now(timezone.utc),
    }
    # Caption metadata is only written when present and non-empty.
    if caption_id:
        doc["caption_id"] = caption_id
    if caption:
        doc["caption"] = caption

    inserted = _get_mongo_collection().insert_one(doc)
    return {
        "audio_file_id": str(inserted.inserted_id),
        "audio_id": audio_id,
        "audio_url": audio_url,
        "sample_rate": int(sample_rate),
        "duration_seconds": float(duration),
        "caption_id": caption_id,
        "caption": caption,
    }


def _generate_audio_from_text(
    text: str,
    sid: int,
    speed: float,
    caption_id: Optional[str] = None,
) -> dict:
    """Synthesize *text* with the English TTS model and store the result.

    The model comes from ``get_pretrained_model(ENGLISH_REPO_ID, speed)``.
    *sid* is forwarded to its ``generate`` call (presumably a speaker id, to be
    confirmed against the model module). The audio is persisted through
    ``_save_audio_to_db``, with *text* recorded as the caption.

    Returns:
        The summary dict produced by ``_save_audio_to_db``.

    Raises:
        ValueError: if the model produced no samples.
    """
    generated = get_pretrained_model(ENGLISH_REPO_ID, speed).generate(text, sid=sid)
    samples = generated.samples
    if not len(samples):
        raise ValueError("No audio was generated.")
    return _save_audio_to_db(
        samples,
        generated.sample_rate,
        caption_id=caption_id,
        caption=text,
    )


# Optional JSON body for POST /audio/by-id. The endpoint also accepts form fields
# with the same names; get_audio_by_id decides which source wins per field.
class AudioByIdRequest(BaseModel):
    # Audio id, Mongo _id of an audio document, or _id of a caption to resolve.
    audio_id: str
    # Forwarded as `sid` to the TTS model's generate() if audio must be synthesized.
    sid: Optional[int] = 0
    # Forwarded to get_pretrained_model() if audio must be synthesized.
    speed: Optional[float] = 1.0
    # Firebase ID token; verified before any lookup is performed.
    firebase_id_token: Optional[str] = None


# FastAPI application; the routes below register on it via @api.get / @api.post.
api = FastAPI(title="Text-to-Speech API")


def _api_response(succes: bool, messase: str, data):
    return {"succes": succes, "messase": messase, "data": data}


@api.get("/health")
def health_check():
    """Liveness probe; always answers {"status": "ok"} without touching MongoDB or Firebase."""
    status_body = dict(status="ok")
    return status_body


def _find_audio_doc(identifier: str):
    """Find an audio document by its public ``audio_id``, falling back to Mongo ``_id``.

    The ``_id`` lookup only runs when *identifier* is a valid ObjectId string.
    Returns the matching document, or None when neither lookup matches.
    """
    collection = _get_mongo_collection()
    by_audio_id = collection.find_one({"audio_id": identifier})
    if by_audio_id:
        return by_audio_id
    if not ObjectId.is_valid(identifier):
        return None
    return collection.find_one({"_id": ObjectId(identifier)})


def _get_firebase_api_key() -> str:
    # Read at runtime so env updates after restarts are reflected.
    return (
        os.getenv("FIREBASE_API_KEY", "").strip()
        or os.getenv("FIREBASE_WEB_API_KEY", "").strip()
        or os.getenv("FIREBASE_APIKEY", "").strip()
    )


def _firebase_json_object(resp) -> Optional[dict]:
    """Return *resp*'s decoded JSON body if it is a JSON object, otherwise None."""
    try:
        body = resp.json()
    except ValueError:
        return None
    return body if isinstance(body, dict) else None


async def verify_firebase_token(firebase_id_token: str) -> dict:
    """Verify a Firebase ID token through the Identity Toolkit ``accounts:lookup`` REST API.

    Args:
        firebase_id_token: The ID token supplied by the client.

    Returns:
        The first user record in the lookup response.

    Raises:
        HTTPException: 401 if the token is missing, rejected, or matches no user;
            500 if no Firebase API key is configured; 502 if Firebase answers 200
            with a body that is not a JSON object; 503 if the Firebase service
            cannot be reached.
    """
    if not firebase_id_token:
        raise HTTPException(status_code=401, detail=ERRORS["TOKEN_MISSING"])
    firebase_api_key = _get_firebase_api_key()
    if not firebase_api_key:
        raise HTTPException(
            status_code=500,
            detail="FIREBASE_API_KEY is missing in environment configuration",
        )

    url = "https://identitytoolkit.googleapis.com/v1/accounts:lookup"
    try:
        # requests is synchronous; run it in a worker thread so the event loop stays free.
        resp = await asyncio.to_thread(
            requests.post,
            url,
            params={"key": firebase_api_key},
            json={"idToken": firebase_id_token},
            timeout=10,
        )
    except requests.RequestException as e:
        # Never echo str(e) to the client or the log: requests' messages can contain
        # the request URL, including the ?key=<API key> query string.
        log(f"Firebase verification request failed: {type(e).__name__}")
        raise HTTPException(
            status_code=503,
            detail="Firebase verification service unavailable",
        ) from e

    body = _firebase_json_object(resp)
    if resp.status_code != 200:
        # Surface Firebase's error code (e.g. {"error": {"message": ...}}) when present.
        error = body.get("error") if body else None
        message = error.get("message") if isinstance(error, dict) else None
        detail = message or ERRORS["TOKEN_INVALID"]
        raise HTTPException(status_code=401, detail=f"Firebase token verification failed: {detail}")

    if body is None:
        raise HTTPException(status_code=502, detail="Invalid response from Firebase verification service")
    users = body.get("users")
    if not isinstance(users, list) or not users:
        raise HTTPException(status_code=401, detail="Firebase token verification failed: no user found")
    return users[0]


# Serializes TTS synthesis. Synthesis used to run directly on the event loop, which
# implicitly allowed only one at a time. It now runs in worker threads, and this lock
# keeps that guarantee in case the model object is not safe to share across threads.
_TTS_LOCK = threading.Lock()


def _resolve_audio_doc(audio_id: str, sid: Optional[int], speed: Optional[float]):
    """Find the audio document for *audio_id*, synthesizing it from a caption if needed.

    Lookup order:
      1. an audio document by public ``audio_id`` or Mongo ``_id``;
      2. if *audio_id* is a valid ObjectId, an audio document whose ``caption_id``
         equals it;
      3. the caption with that ``_id``, whose non-empty text is synthesized and
         stored on the fly.

    Performs blocking MongoDB and TTS calls; call it from a worker thread in
    async code. Returns the document or None. Synthesis failures are logged and
    reported as None, not raised.
    """
    doc = _find_audio_doc(audio_id)
    if doc or not ObjectId.is_valid(audio_id):
        return doc

    doc = _get_mongo_collection().find_one({"caption_id": audio_id})
    if doc:
        return doc

    caption_doc = _get_captions_collection().find_one({"_id": ObjectId(audio_id)})
    if not caption_doc:
        return None
    caption_text = str(caption_doc.get("caption", "")).strip()
    if not caption_text:
        return None

    try:
        with _TTS_LOCK:
            saved = _generate_audio_from_text(
                caption_text,
                sid=sid,
                speed=speed,
                caption_id=audio_id,
            )
        return _find_audio_doc(saved["audio_id"])
    except Exception as e:
        # Best effort: a failed synthesis surfaces to the client as "Audio not found".
        log(f"Error generating audio from caption {audio_id}: {e}")
        return None


@api.post("/audio/by-id")
async def get_audio_by_id(
    request: Request,
    payload: Optional[AudioByIdRequest] = Body(default=None),
    audio_id: Optional[str] = Form(default=None),
    sid: Optional[int] = Form(default=0),
    speed: Optional[float] = Form(default=1.0),
    firebase_id_token: Optional[str] = Form(default=None),
):
    """Return metadata and an absolute playback URL for an audio clip.

    Parameters may come from the JSON ``payload`` or from form fields. For
    ``audio_id`` and ``firebase_id_token`` the form value wins. For ``sid`` and
    ``speed`` the payload value wins whenever it is not None.

    The Firebase token is verified first. ``audio_id`` may be an audio id, a
    Mongo ``_id``, or a caption ``_id``; in the last case audio may be synthesized
    on demand. "Not found"-style outcomes are returned in the standard envelope
    with ``succes=False``.

    Raises:
        HTTPException: propagated from token verification; 503 on MongoDB errors
            outside synthesis; 500 on any other unexpected error.
    """
    try:
        resolved_audio_id = audio_id or (payload.audio_id if payload else None)
        resolved_sid = payload.sid if payload and payload.sid is not None else sid
        resolved_speed = payload.speed if payload and payload.speed is not None else speed
        resolved_firebase_token = firebase_id_token or (payload.firebase_id_token if payload else None)

        await verify_firebase_token(resolved_firebase_token)

        if not resolved_audio_id:
            return _api_response(False, "audio_id is required", None)

        # pymongo and TTS synthesis are blocking; run them in a worker thread so a
        # slow lookup or synthesis does not stall every other request on the loop.
        doc = await asyncio.to_thread(
            _resolve_audio_doc, resolved_audio_id, resolved_sid, resolved_speed
        )
        if not doc:
            return _api_response(False, "Audio not found", None)

        # `or b""` also covers a stored null, which bytes() would reject with TypeError.
        audio_bytes = bytes(doc.get("audio_file") or b"")
        if not audio_bytes:
            return _api_response(False, "Document found but audio_file is missing", None)

        resolved_id = str(doc.get("audio_id") or doc.get("_id"))
        audio_url = str(request.base_url) + f"audio/{resolved_id}.opus"

        return _api_response(
            True,
            "Audio fetched successfully",
            {
                "audio_id": resolved_id,
                "audio_url": audio_url,
                # `or` fallbacks keep stored nulls from raising TypeError in int()/float().
                "sample_rate": int(doc.get("sample_rate") or 0),
                "duration_seconds": float(doc.get("duration_seconds") or 0.0),
                "caption": doc.get("caption"),
            },
        )
    except HTTPException:
        raise
    except PyMongoError as e:
        log(f"MongoDB error in /audio/by-id: {e}")
        raise HTTPException(status_code=503, detail="Database service unavailable")
    except Exception as e:
        log(f"Unhandled error in /audio/by-id: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")


@api.get("/audio/{audio_id}.opus")
def stream_audio(audio_id: str):
    """Serve the stored Ogg/Opus bytes for *audio_id* (public audio id or Mongo ``_id``).

    Success responses carry ``audio/ogg`` content, an inline Content-Disposition,
    and a one-year ``public`` Cache-Control header. Failures return an empty body:
    404 when the document or its audio is missing, 503 on MongoDB errors, 500 on
    anything else.
    """
    try:
        doc = _find_audio_doc(audio_id)
        if not doc:
            return Response(status_code=404)

        # `or b""` also covers a stored null, which bytes() would reject (turning a
        # missing-audio 404 into a 500).
        audio_bytes = bytes(doc.get("audio_file") or b"")
        if not audio_bytes:
            return Response(status_code=404)

        resolved_id = str(doc.get("audio_id") or doc.get("_id"))
        return Response(
            content=audio_bytes,
            media_type="audio/ogg",
            headers={
                "Content-Disposition": f'inline; filename="{resolved_id}.opus"',
                "Cache-Control": "public, max-age=31536000",
            },
        )
    except PyMongoError as e:
        log(f"MongoDB error in /audio/{{audio_id}}.opus: {e}")
        return Response(status_code=503)
    except Exception as e:
        log(f"Unhandled error in /audio/{{audio_id}}.opus: {e}")
        return Response(status_code=500)


# Conventional ASGI entry-point name (e.g. `uvicorn app:app`); the same object as `api`.
app = api
# Import-time diagnostic: reports only whether a Firebase key is configured, never its value.
log(f"FIREBASE API key present at startup: {bool(_get_firebase_api_key())}")


if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "app:app" import string. The string form makes
    # uvicorn import this file a second time under the module name "app" (re-running
    # module-level side effects), and it fails if the file is not named app.py.
    # An import string is only required for reload/workers, and reload is off here.
    uvicorn.run(app, host="0.0.0.0", port=7860, reload=False)