File size: 3,837 Bytes
0db822c
 
 
 
 
 
9308938
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9308938
 
 
 
 
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse

from api.config import settings
from api.routers.transcription import router as transcription_router
from api.schemas import HealthResponse

# Process-wide logging: INFO and above, pipe-separated columns.
_LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger, named after this module.
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load inference resources on startup and clear them on shutdown.

    Startup populates two attributes on ``app.state``:

    * ``transcriber`` -- a ``WhisperTranscriber`` built from
      ``settings.model_path`` and ``settings.device``, or ``None`` if its
      construction raised.
    * ``analyzer`` -- a ``CallAnalyzer`` built from
      ``settings.gemini_api_key``, or ``None`` if no key is configured or its
      construction raised.

    Construction failures are logged with a traceback and swallowed, so the
    application still starts in a degraded state. Both attributes are reset
    to ``None`` on shutdown. The reset happens even when the lifespan is
    exited because of an exception.
    """
    # ── Startup ───────────────────────────────────────────────────────────────
    # Deferred imports: the inference modules are only imported when the app
    # starts, not when this module is imported.
    from src.inference.transcribe import WhisperTranscriber
    from src.inference.analyze_call import CallAnalyzer

    # Start from a known "nothing loaded" state; successful loads overwrite it.
    app.state.transcriber = None
    app.state.analyzer = None

    logger.info("Loading Whisper model from: %s", settings.model_path)
    try:
        # Construct in the threadpool so a slow load does not stall the event loop.
        app.state.transcriber = await run_in_threadpool(
            WhisperTranscriber, settings.model_path, settings.device
        )
    except Exception:
        # Deliberate best-effort: keep serving in a degraded state.
        logger.exception("Failed to load Whisper model — /transcribe endpoints will return 503.")
    else:
        logger.info("Whisper model loaded successfully.")

    if settings.gemini_api_key:
        # NOTE(review): settings.gemini_model is logged here but not passed to
        # CallAnalyzer — confirm CallAnalyzer selects the same model itself.
        logger.info("Initialising Gemini analyzer (model=%s).", settings.gemini_model)
        try:
            app.state.analyzer = await run_in_threadpool(
                CallAnalyzer, settings.gemini_api_key
            )
        except Exception:
            logger.exception("Failed to init Gemini — /corrected and /analyze will return 503.")
        else:
            logger.info("Gemini analyzer ready.")
    else:
        logger.warning(
            "GEMINI_API_KEY is not set. "
            "POST /api/v1/transcribe/corrected and /analyze are disabled."
        )

    try:
        yield
    finally:
        # ── Shutdown ──────────────────────────────────────────────────────────
        # In ``finally`` so the references are dropped even if an exception is
        # thrown into the context manager at ``yield``.
        logger.info("Shutting down — releasing model resources.")
        app.state.transcriber = None
        app.state.analyzer = None


_API_DESCRIPTION = (
    "Arabic speech transcription powered by a fine-tuned Whisper model, "
    "with optional Gemini post-processing for speaker diarisation, "
    "phonetic correction, and real estate call analysis."
)

# Application instance; ``lifespan`` sets up and tears down the inference
# objects stored on ``app.state``.
app = FastAPI(
    title="Speech-to-Text API",
    version="1.0.0",
    description=_API_DESCRIPTION,
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)

# CORS: every origin and header is accepted; GET and POST are the listed
# methods.
# NOTE(review): wildcard origins — confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["POST", "GET"],
    allow_origins=["*"],
)

# Mount the transcription endpoints.
app.include_router(transcription_router)


@app.get("/", include_in_schema=False)
async def root() -> RedirectResponse:
    """Redirect the bare root URL to the interactive API docs at ``/docs``."""
    docs_location = "/docs"
    return RedirectResponse(url=docs_location)


@app.get("/health", response_model=HealthResponse, tags=["system"])
async def health(request: Request) -> HealthResponse:
    """Report which inference components are currently loaded.

    ``status`` is ``"ok"`` when the Whisper transcriber is present on
    ``app.state`` and ``"degraded"`` otherwise. Gemini availability is
    reported separately and does not affect ``status``.
    """
    state = request.app.state
    whisper_ready = getattr(state, "transcriber", None) is not None
    gemini_ready = getattr(state, "analyzer", None) is not None

    if whisper_ready:
        overall = "ok"
    else:
        overall = "degraded"

    return HealthResponse(
        status=overall,
        whisper_loaded=whisper_ready,
        gemini_available=gemini_ready,
        model_path=settings.model_path,
    )


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Log an uncaught exception with its traceback and answer with a 500.

    The response body is a fixed message and carries no exception details.
    """
    method, path = request.method, request.url.path
    logger.exception("Unhandled exception for %s %s", method, path)
    body = {"detail": "Internal server error."}
    return JSONResponse(content=body, status_code=500)