File size: 4,076 Bytes
6252f54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""Training metrics and control endpoints."""

import asyncio
import logging
import uuid
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

from backend.dependencies import get_neo4j_client
from backend.graph.neo4j_client import Neo4jClient
from backend.graph.cypher_queries import GET_TRAINING_METRICS, GET_ENRICHMENT_COVERAGE

log = logging.getLogger(__name__)
router = APIRouter()

_training_tasks: dict[str, str] = {}  # run_id → status


class TrainingRequest(BaseModel):
    episodes_per_domain: int = 50
    domain: str | None = None


@router.get("/training/metrics")
async def training_metrics(neo4j: Annotated[Neo4jClient, Depends(get_neo4j_client)]):
    rows = neo4j.run_query(GET_TRAINING_METRICS)
    return [
        {
            "run_id": r.get("run_id"),
            "domain_name": r.get("domain_name"),
            "sector": r.get("sector"),
            "episodes": r.get("episodes"),
            "final_reward": r.get("final_reward"),
            "avg_reward_last10": r.get("avg_reward_last10"),
            "device": r.get("device"),
            "policy_version": r.get("policy_version"),
            "ts": r.get("ts"),
        }
        for r in rows
    ]


@router.get("/training/coverage")
async def training_coverage(neo4j: Annotated[Neo4jClient, Depends(get_neo4j_client)]):
    rows = neo4j.run_query(GET_ENRICHMENT_COVERAGE)
    return [
        {
            "domain": r.get("domain"),
            "has_standard": bool(r.get("has_standard")),
            "standard_enriched": bool(r.get("standard_enriched")),
            "has_trend": bool(r.get("has_trend")),
            "trend_enriched": bool(r.get("trend_enriched")),
            "drl_trained": bool(r.get("drl_trained")),
            "drl_reward": r.get("drl_reward"),
            "drl_last_trained": r.get("drl_last_trained"),
        }
        for r in rows
    ]


@router.post("/training/run")
async def trigger_training(
    request: TrainingRequest,
    neo4j: Annotated[Neo4jClient, Depends(get_neo4j_client)],
):
    run_id = uuid.uuid4().hex[:12]
    _training_tasks[run_id] = "started"

    async def _run():
        try:
            import os
            from dotenv import load_dotenv
            load_dotenv()
            from neo4j import GraphDatabase
            from pipeline.train_on_graph import GraphTrainer, _NeoJClientAdapter

            uri = os.getenv("NEO4J_URI")
            username = os.getenv("NEO4J_USERNAME")
            password = os.getenv("NEO4J_PASSWORD")
            database = os.getenv("NEO4J_DATABASE", "neo4j")
            driver = GraphDatabase.driver(uri, auth=(username, password))
            try:
                trainer = GraphTrainer(driver, database)
                loop = asyncio.get_event_loop()
                result = await loop.run_in_executor(
                    None,
                    lambda: trainer.run(
                        episodes_per_domain=request.episodes_per_domain,
                        target_domain=request.domain,
                    ),
                )
                _training_tasks[run_id] = f"completed:{result['domains_trained']} domains"
                log.info(f"Background training {run_id} complete: {result}")
            finally:
                driver.close()
        except Exception as exc:
            _training_tasks[run_id] = f"error:{exc}"
            log.exception(f"Background training {run_id} failed")

    asyncio.create_task(_run())
    return {
        "status": "started",
        "run_id": run_id,
        "message": f"Training {request.episodes_per_domain} eps/domain"
                   + (f" on '{request.domain}'" if request.domain else " across all domains")
                   + ". Metrics appear in Neo4j as each domain completes.",
    }


@router.get("/training/status/{run_id}")
async def training_status(run_id: str):
    status = _training_tasks.get(run_id)
    if status is None:
        raise HTTPException(status_code=404, detail="run_id not found")
    return {"run_id": run_id, "status": status}