File size: 2,042 Bytes
047d92b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f961e0d
 
047d92b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f961e0d
047d92b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import os

from dotenv import load_dotenv

# ===========================
# !!! ATTENTION !!!
# KEEP THIS ABOVE THE IMPORTS BELOW SO ENVIRONMENT VARIABLES ARE LOADED BEFORE THOSE MODULES ARE IMPORTED
# ===========================
load_dotenv()

from contextlib import asynccontextmanager

from alembic import command
from alembic.config import Config
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

from src.configs import DatabaseConfig
from src.entities import api_router as api_router_entities
from src.controllers import api_router

from seed_vector_data import seed_vector_data



def run_upgrade(connection, alembic_config: Config):
    """Upgrade the schema to Alembic's ``head`` revision over *connection*.

    This is a plain synchronous function on purpose: ``run_migrations``
    invokes it through ``run_sync`` on the connection it opens.

    Args:
        connection: The connection handed over by ``run_sync``.
        alembic_config: The Alembic configuration to run the upgrade with.
    """
    # Stash the live connection where the migration environment can reuse it
    # (presumably read back in env.py via ``config.attributes`` — verify).
    shared = alembic_config.attributes
    shared["connection"] = connection
    command.upgrade(alembic_config, revision="head")


def _escape_config_percent(value: str) -> str:
    """Escape ``%`` so ConfigParser interpolation stores *value* verbatim.

    Alembic keeps main options in a ``configparser.ConfigParser`` with basic
    interpolation. There, a bare ``%`` (e.g. a URL-encoded password such as
    ``p%40ss``) is rejected with ``ValueError``, and ``%(name)s`` would be
    expanded. Doubling every ``%`` makes the value round-trip unchanged.
    """
    return value.replace("%", "%%")


async def run_migrations():
    """Upgrade the database schema to the latest Alembic revision.

    Loads ``alembic.ini`` (path relative to the current working directory) and
    overrides ``sqlalchemy.url`` with the ``SQLALCHEMY_DATABASE_URI``
    environment variable. It then runs ``run_upgrade`` synchronously via
    ``run_sync`` on a connection opened with ``DatabaseConfig.get_engine().begin()``.

    Raises:
        RuntimeError: If ``SQLALCHEMY_DATABASE_URI`` is not set.
    """
    logger.info("Running migrations if any...")
    database_uri = os.getenv("SQLALCHEMY_DATABASE_URI")
    if database_uri is None:
        # Fail with a clear message; passing None on to ConfigParser would
        # raise an opaque "option values must be strings" TypeError.
        raise RuntimeError(
            "SQLALCHEMY_DATABASE_URI is not set; cannot run database migrations."
        )
    alembic_config = Config("alembic.ini")
    alembic_config.set_main_option(
        "sqlalchemy.url", _escape_config_percent(database_uri)
    )
    # begin() yields a connection (not an ORM session); presumably a SQLAlchemy
    # AsyncEngine that commits on successful exit — confirm in DatabaseConfig.
    async with DatabaseConfig.get_engine().begin() as connection:
        await connection.run_sync(run_upgrade, alembic_config)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests, and log on exit.

    Startup awaits ``run_migrations()`` and then ``seed_vector_data()``, in
    that order. Any ``Exception`` raised during startup, or thrown back in at
    the ``yield``, is logged with its traceback and re-raised. The closing
    log line is emitted unconditionally, including after a failed startup.
    """
    try:
        logger.info("Starting up the application...")
        startup_steps = (run_migrations, seed_vector_data)
        for step in startup_steps:
            await step()
        logger.info("Application started successfully...")
        yield
    except Exception as error:
        logger.exception(error)
        raise
    finally:
        # Reached on normal shutdown and on startup failure alike.
        logger.info("Application shutdown complete.")


def _parse_cors_origins(raw: str) -> list:
    """Return the whitespace-trimmed, non-empty entries of a comma-separated list.

    Entries left empty by a trailing or doubled comma (``"a,,b,"``) or by an
    empty variable are dropped instead of being forwarded to the middleware
    as ``""``.
    """
    return [origin for origin in (part.strip() for part in raw.split(",")) if origin]


app = FastAPI(lifespan=lifespan)


# Allowed origins come from CORS_ALLOW_ORIGINS (comma-separated). When it is
# unset, only the two localhost origins below are allowed.
# NOTE(review): allow_credentials=True together with a "*" entry in
# CORS_ALLOW_ORIGINS deserves scrutiny before deploying with a wildcard.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_parse_cors_origins(
        os.getenv("CORS_ALLOW_ORIGINS", "http://localhost, http://127.0.0.1")
    ),
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def check_health():
    """Return a static health message for ``GET /``."""
    payload = {"response": "Service is healthy!"}
    return payload


# Both routers are mounted under the same "/api" prefix; the registration
# order (entities first) is kept as-is in case paths overlap.
app.include_router(router=api_router_entities, prefix="/api")
app.include_router(router=api_router, prefix="/api")