File size: 4,744 Bytes
ac5551d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e3f1b
ac5551d
99e3f1b
ac5551d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e3f1b
ac5551d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99e3f1b
ac5551d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
main.py — FastAPI application entry point.
Wires together all modules, registers middleware/routes, manages lifespan.
"""
from __future__ import annotations

import os
import sys

# Prepend this file's directory (the backend root) to sys.path, unless it is
# already present, so the sibling top-level packages imported in this file
# (api, config, database, middleware, observability, registry, datasets)
# resolve regardless of the process's working directory. This enables plain
# `api.*` / `config` style imports — not `backend.*` ones — and, when the
# entry is newly inserted at position 0, these local packages take precedence
# over same-named installed packages.
backend_root = os.path.dirname(os.path.abspath(__file__))
if backend_root not in sys.path:
    sys.path.insert(0, backend_root)

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

import traceback

from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

from api.routes import models as models_router
from api.routes import sync as sync_router
from api.routes import datasets as datasets_router
from config import settings
from database.connection import close_db, get_db
from middleware.logging_middleware import RequestLoggingMiddleware
from observability.logger import configure_logging, get_logger

# ── Logging bootstrap (must be first) ─────────────────────────────────────────
# Configure logging once, at import time, before creating the module-level
# `log` that the lifespan, exception handler and startup messages below use.
configure_logging()
log = get_logger("main")


# ── Lifespan ──────────────────────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Run startup work, hand control to the app, then run shutdown work.

    Startup creates the configured directories, bootstraps the database via
    ``get_db()``, and — when ``settings.auto_sync_on_startup`` is enabled and
    the registry currently holds zero models — schedules ``_run_full_sync()``
    as a background task.

    Shutdown always runs, even if the app exits with an error. It cancels the
    background sync if that is still in flight, then closes the database.

    Args:
        app: The FastAPI application (unused; required by FastAPI's lifespan
            signature).
    """
    # Startup
    settings.ensure_dirs()
    log.info("startup", host=settings.host, port=settings.port, version=settings.version)
    await get_db()   # Bootstrap DB / run migrations
    log.info("database_ready", path=str(settings.db_path))

    sync_task: asyncio.Task | None = None
    if settings.auto_sync_on_startup:
        from registry.registry import count_models

        current = await count_models()
        if current == 0:
            from api.routes.sync import _run_full_sync

            log.info("auto_sync_startup_triggered")
            # Keep a strong reference: the event loop only holds weak
            # references to tasks, so a fire-and-forget task can be
            # garbage-collected before it finishes.
            sync_task = asyncio.create_task(_run_full_sync())

    try:
        yield  # ← app runs
    finally:
        # Shutdown
        if sync_task is not None and not sync_task.done():
            # Stop the sync before close_db() so it does not keep running
            # against a closed database (it presumably uses the DB — confirm).
            sync_task.cancel()
            try:
                await sync_task
            except asyncio.CancelledError:
                pass
            except Exception:
                # Never let a failing sync prevent the DB from being closed.
                log.error("auto_sync_failed", traceback=traceback.format_exc())
        await close_db()
        log.info("shutdown")


# ── Application ───────────────────────────────────────────────────────────────
# Interactive API docs are served at /docs and /redoc; startup/shutdown work
# is delegated to the `lifespan` context manager defined above.
app = FastAPI(
    title="MLForge Cloud Registry",
    version=settings.version,
    # Shown verbatim in the OpenAPI docs UI; keep this file UTF-8 encoded.
    description="Global Model and Dataset Discovery Service — The Brain of MLForge.",
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)


@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Log an unhandled exception with its full traceback and return a JSON 500.

    Args:
        request: The request whose handling raised ``exc``.
        exc: The unhandled exception.

    Returns:
        A 500 ``JSONResponse`` whose body holds a generic ``detail`` plus the
        exception message under ``error``.
    """
    # Format from the exception object itself. traceback.format_exc() relies on
    # being called while the exception is still "being handled" and yields
    # "NoneType: None" otherwise (e.g. when this handler is invoked directly).
    tb_text = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
    log.error(
        "unhandled_exception",
        method=request.method,
        path=request.url.path,
        error_type=type(exc).__name__,
        error=str(exc),
        traceback=tb_text,
    )
    # NOTE(review): echoing str(exc) to clients can leak internals (file paths,
    # SQL, config values); consider including it only when settings.debug is set.
    return JSONResponse(
        status_code=500,
        content={"detail": "Internal Server Error", "error": str(exc)},
    )

# ── Middleware ─────────────────────────────────────────────────────────────────
# CORS: any origin may call the registry (SDK/CLI/UI). Credentials are disabled
# on purpose. With allow_origins=["*"] and allow_credentials=True, Starlette
# echoes the caller's Origin back (instead of "*") and allows credentials, so
# every website could issue cookie-authenticated requests. FastAPI's docs state
# the "*" wildcards must not be combined with allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Added last, so it is the outermost middleware and wraps CORSMiddleware.
app.add_middleware(RequestLoggingMiddleware)

# ── Routes ────────────────────────────────────────────────────────────────────
# Mount every API router on the app, in a fixed order: models, sync, datasets.
for _api_module in (models_router, sync_router, datasets_router):
    app.include_router(_api_module.router)
del _api_module  # keep the module namespace free of the loop variable


@app.get("/health", tags=["system"])
async def health() -> dict:
    """Health-check endpoint reporting service identity, version and registry sizes.

    Returns:
        A dict with ``status`` (always ``"ok"``), ``service``, ``version`` and
        the current ``model_count`` and ``dataset_count``.
    """
    # Imported inside the handler, as originally wired — presumably to avoid
    # import-time cycles; confirm before hoisting to module level.
    from registry.registry import count_models
    from datasets.registry import count_datasets

    model_total = await count_models()
    dataset_total = await count_datasets()
    return dict(
        status="ok",
        service="cloud_registry",
        version=settings.version,
        model_count=model_total,
        dataset_count=dataset_total,
    )


# ── Dev runner ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    import uvicorn

    # The app is passed as the import string "main:app" rather than the object;
    # uvicorn requires an import string when reload is enabled.
    uvicorn_kwargs = {
        "host": settings.host,
        "port": settings.port,
        "reload": settings.debug,
        # Don't let uvicorn install its own logging config; this app uses structlog.
        "log_config": None,
    }
    uvicorn.run("main:app", **uvicorn_kwargs)