File size: 3,404 Bytes
6252f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""FastAPI application — startup lifecycle and router registration."""

import logging
import os
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from backend.config import get_settings
from backend.dependencies import set_neo4j_client, set_llm_client
from backend.graph.neo4j_client import Neo4jClient
from backend.llm.client import LLMClient

log = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Set up shared clients on startup and release them on shutdown.

    Startup:
      * builds a ``Neo4jClient`` from settings, registers it with
        ``set_neo4j_client`` and logs the result of ``verify_connectivity()``;
      * builds an ``LLMClient`` and registers it with ``set_llm_client``;
      * tries to pre-load the DRL policy checkpoint into
        ``app.state.drl_policy``. The value is ``None`` when the checkpoint
        file is missing or loading fails for any reason;
      * logs which GPU (and ROCm/HIP version, if present) torch reports.

    Shutdown:
      * closes the Neo4j client. This happens in a ``finally`` block, so the
        driver is also closed when a later startup step raises. The same
        applies when an exception or cancellation is delivered at ``yield``.
    """
    settings = get_settings()

    # Neo4j
    client = Neo4jClient(
        uri=settings.neo4j_uri,
        username=settings.neo4j_username,
        password=settings.neo4j_password,
        database=settings.neo4j_database,
    )
    set_neo4j_client(client)
    try:
        ok = client.verify_connectivity()
        log.info("Neo4j connectivity: %s", "OK" if ok else "FAILED")

        # LLM
        # NOTE(review): LLMClient is never explicitly closed here; confirm it
        # holds no resources that need shutdown cleanup.
        llm = LLMClient(settings=settings)
        set_llm_client(llm)

        # DRL checkpoint (pre-load so the first request isn't slow). Best
        # effort: any failure, including the import itself, degrades to None.
        try:
            from backend.drl.trainer import load_trained_policy
            checkpoint = settings.drl_checkpoint_path
            if os.path.exists(checkpoint):
                app.state.drl_policy = load_trained_policy(checkpoint)
                log.info("DRL policy loaded from %s", checkpoint)
            else:
                app.state.drl_policy = None
                log.warning(
                    "DRL checkpoint not found at %s; optimizer will use random priorities",
                    checkpoint,
                )
        except Exception as exc:
            # Keep the traceback so a broken checkpoint/loader is diagnosable.
            log.warning("Could not load DRL policy: %s", exc, exc_info=True)
            app.state.drl_policy = None

        # GPU info
        if torch.cuda.is_available():
            log.info("GPU: %s", torch.cuda.get_device_name(0))
            # torch.version.hip is only set on ROCm builds of torch.
            hip_version = getattr(torch.version, "hip", None)
            if hip_version:
                log.info("AMD ROCm: %s", hip_version)
        else:
            log.info("No GPU detected — running on CPU")

        yield
    finally:
        client.close()
        log.info("Neo4j driver closed")


def create_app() -> FastAPI:
    """Construct the FastAPI application.

    The returned app runs ``lifespan`` for startup/shutdown. It has permissive
    CORS middleware and mounts every API router under the ``/api/v1`` prefix.
    The router modules are imported inside this function instead of at module
    top. NOTE(review): presumably this avoids import cycles; confirm before
    hoisting them.
    """
    application = FastAPI(
        title="AMD Enterprise Architecture Optimizer",
        description="Agentic EA roadmap generation powered by AMD MI300X + ROCm",
        version="1.0.0",
        lifespan=lifespan,
    )

    # NOTE(review): wildcard origins together with allow_credentials=True means
    # Starlette echoes the caller's Origin on credentialed requests, i.e. any
    # site may make cookie-bearing calls. Restrict allow_origins for production.
    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    from backend.api.routes_health import router as health_router
    from backend.api.routes_graph import router as graph_router
    from backend.api.routes_analyze import router as analyze_router
    from backend.api.routes_training import router as training_router
    from backend.api.routes_chat import router as chat_router
    from backend.api.routes_integrations import router as integrations_router

    # (router, OpenAPI tag) pairs; mounted in this order, since FastAPI
    # matches routes in registration order.
    api_prefix = "/api/v1"
    mounted = (
        (health_router, "health"),
        (graph_router, "graph"),
        (analyze_router, "analyze"),
        (training_router, "training"),
        (chat_router, "chat"),
        (integrations_router, "integrations"),
    )
    for router, tag in mounted:
        application.include_router(router, prefix=api_prefix, tags=[tag])

    return application


# ASGI application object built at import time; ASGI servers load this name.
app = create_app()

if __name__ == "__main__":
    # Local development entry point. uvicorn's reload mode needs the app as an
    # import string rather than the object itself. NOTE(review): the string
    # presumes this module is importable as backend.main.
    import uvicorn

    uvicorn.run(
        "backend.main:app",
        host="0.0.0.0",
        port=8080,
        reload=True,
    )