Spaces:
Sleeping
Sleeping
File size: 4,076 Bytes
6252f54 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 | """Training metrics and control endpoints."""
import asyncio
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from backend.dependencies import get_neo4j_client
from backend.graph.neo4j_client import Neo4jClient
from backend.graph.cypher_queries import GET_TRAINING_METRICS, GET_ENRICHMENT_COVERAGE
log = logging.getLogger(__name__)
router = APIRouter()

# In-memory registry of training runs launched by this process, keyed by run_id.
# Values are status strings: "started", "completed:<n> domains" or "error:<msg>".
# NOTE(review): process-local, so it is lost on restart and not shared between
# worker processes — confirm that is acceptable for the deployment.
_training_tasks: dict[str, str] = dict()
class TrainingRequest(BaseModel):
    """Request body accepted by ``POST /training/run``."""

    # How many training episodes to run for each domain.
    episodes_per_domain: int = 50
    # Name of a single domain to train on; ``None`` targets all domains.
    domain: str | None = None
@router.get("/training/metrics")
async def training_metrics(neo4j: Annotated[Neo4jClient, Depends(get_neo4j_client)]):
    """List recorded training runs as JSON-serialisable dicts.

    Executes ``GET_TRAINING_METRICS`` and projects every returned row onto a
    fixed set of keys; a column missing from a row is reported as ``None``.
    """
    # ``run_query`` is synchronous (its return value is iterated directly), so
    # calling it inline would block the event loop for the whole DB round-trip.
    # Run it on the default thread pool instead.
    # NOTE(review): assumes Neo4jClient.run_query is safe to call from a worker
    # thread — confirm against the client implementation.
    rows = await asyncio.to_thread(neo4j.run_query, GET_TRAINING_METRICS)
    return [
        {
            "run_id": r.get("run_id"),
            "domain_name": r.get("domain_name"),
            "sector": r.get("sector"),
            "episodes": r.get("episodes"),
            "final_reward": r.get("final_reward"),
            "avg_reward_last10": r.get("avg_reward_last10"),
            "device": r.get("device"),
            "policy_version": r.get("policy_version"),
            "ts": r.get("ts"),
        }
        for r in rows
    ]
@router.get("/training/coverage")
async def training_coverage(neo4j: Annotated[Neo4jClient, Depends(get_neo4j_client)]):
    """Report per-domain enrichment and DRL-training coverage.

    Executes ``GET_ENRICHMENT_COVERAGE``; the ``has_*``/``*_enriched``/
    ``drl_trained`` columns are coerced to ``bool`` (missing -> ``False``),
    while ``drl_reward`` and ``drl_last_trained`` are passed through as-is.
    """
    # ``run_query`` is synchronous, so run it on the default thread pool rather
    # than blocking the event loop while Neo4j answers.
    # NOTE(review): assumes Neo4jClient.run_query is safe to call from a worker
    # thread — confirm against the client implementation.
    rows = await asyncio.to_thread(neo4j.run_query, GET_ENRICHMENT_COVERAGE)
    return [
        {
            "domain": r.get("domain"),
            "has_standard": bool(r.get("has_standard")),
            "standard_enriched": bool(r.get("standard_enriched")),
            "has_trend": bool(r.get("has_trend")),
            "trend_enriched": bool(r.get("trend_enriched")),
            "drl_trained": bool(r.get("drl_trained")),
            "drl_reward": r.get("drl_reward"),
            "drl_last_trained": r.get("drl_last_trained"),
        }
        for r in rows
    ]
# Strong references to in-flight background training tasks. The event loop only
# keeps weak references to tasks, so a task from ``asyncio.create_task`` that is
# not stored anywhere may be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


@router.post("/training/run")
async def trigger_training(
    request: TrainingRequest,
    neo4j: Annotated[Neo4jClient, Depends(get_neo4j_client)],
):
    """Start a training run in the background and return its ``run_id`` at once.

    Progress is recorded in ``_training_tasks[run_id]``: ``"started"`` first,
    then ``"completed:<n> domains"`` on success or ``"error:<message>"`` on
    failure. Poll ``GET /training/status/{run_id}`` to read it.

    The injected ``neo4j`` client is not used: the trainer opens its own driver
    from the ``NEO4J_*`` environment variables (after ``load_dotenv()``). The
    parameter is kept so the route's dependency signature does not change.
    """
    run_id = uuid.uuid4().hex[:12]
    _training_tasks[run_id] = "started"

    async def _run():
        try:
            import os

            from dotenv import load_dotenv

            # Load .env before importing the trainer, in case it reads
            # environment variables at import time.
            load_dotenv()
            from neo4j import GraphDatabase
            from pipeline.train_on_graph import GraphTrainer

            uri = os.getenv("NEO4J_URI")
            if not uri:
                # Fail with a clear status instead of an opaque driver error.
                raise RuntimeError("NEO4J_URI is not set")
            username = os.getenv("NEO4J_USERNAME")
            password = os.getenv("NEO4J_PASSWORD")
            database = os.getenv("NEO4J_DATABASE", "neo4j")
            driver = GraphDatabase.driver(uri, auth=(username, password))
            try:
                trainer = GraphTrainer(driver, database)
                # trainer.run is synchronous; execute it off the event loop so
                # the API keeps serving requests while training.
                result = await asyncio.to_thread(
                    trainer.run,
                    episodes_per_domain=request.episodes_per_domain,
                    target_domain=request.domain,
                )
                _training_tasks[run_id] = f"completed:{result['domains_trained']} domains"
                log.info("Background training %s complete: %s", run_id, result)
            finally:
                driver.close()
        except Exception as exc:
            # Top-level boundary of a detached task: record and log the failure
            # so it is visible via the status endpoint instead of being lost.
            _training_tasks[run_id] = f"error:{exc}"
            log.exception("Background training %s failed", run_id)

    task = asyncio.create_task(_run())
    # Keep the task alive until it finishes, then drop the reference.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {
        "status": "started",
        "run_id": run_id,
        "message": f"Training {request.episodes_per_domain} eps/domain"
        + (f" on '{request.domain}'" if request.domain else " across all domains")
        + ". Metrics appear in Neo4j as each domain completes.",
    }
@router.get("/training/status/{run_id}")
async def training_status(run_id: str):
    """Return the recorded status string for a background training run.

    Raises:
        HTTPException: 404 when no status is stored for ``run_id``.
    """
    current = _training_tasks.get(run_id)
    if current is not None:
        return {"run_id": run_id, "status": current}
    raise HTTPException(status_code=404, detail="run_id not found")
|