scheduler / api /main.py
Owen Kosman
genesis
2f68b3a
from __future__ import annotations

import csv
import io
import logging
from pathlib import Path
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from ortools.sat.python import cp_model

from core.models import CompileInput, Ontology, RuleIR
from nlp.parser import parse_rules_with_grounding
from solver.builder import build_model

from .schemas import (
    CompileRequest,
    CompileResponse,
    ParseRulesRequest,
    ParseRulesResponse,
    ProjectCreateRequest,
    ProjectCreateResponse,
    SolveRequest,
    SolveResponse,
)
from .storage import create_project_dir, write_artifact_text, write_artifact_json
# FastAPI application object; every endpoint in this module registers on it via @app.post.
app = FastAPI(title="Scheduler Service", version="0.1.0")
@app.post("/projects", response_model=ProjectCreateResponse)
def create_project(req: ProjectCreateRequest) -> ProjectCreateResponse:
    """Create the storage directory for a project and echo its identifiers.

    The returned ``project_id`` is the name of the directory produced by
    ``create_project_dir``. NOTE(review): presumably this equals
    ``req.project_id`` unless ``create_project_dir`` normalizes or generates
    the name; confirm against ``.storage``.
    """
    created_dir = create_project_dir(req.tenant_id, req.project_id)
    response = ProjectCreateResponse(
        tenant_id=req.tenant_id,
        project_id=created_dir.name,
    )
    return response
@app.post("/parse_rules", response_model=ParseRulesResponse)
def parse_rules(req: ParseRulesRequest) -> ParseRulesResponse:
    """Parse free-text rules into ``RuleIR`` objects grounded in an ontology.

    An empty or missing ``req.rules_text`` skips parsing and yields empty
    ``rules``/``unresolved``/``warnings`` lists. When ``tenant_id``,
    ``project_id`` and ``persist`` are all set, the raw text, the parsed rules
    and the validated ontology are written as project artifacts.

    Raises:
        HTTPException: status 400 if ``req.ontology`` fails ``Ontology`` validation.
    """
    try:
        onto = Ontology.model_validate(req.ontology)
    except Exception as e:
        # Chain the validation error so server-side tracebacks show the real cause.
        raise HTTPException(status_code=400, detail=f"Invalid ontology: {e}") from e

    rules: List[RuleIR] = []
    unresolved: List[str] = []
    warnings: List[str] = []
    if req.rules_text:
        rules, unresolved, warnings = parse_rules_with_grounding(req.rules_text, onto)

    # Serialize once; the same dumps feed both the persisted artifact and the response.
    rules_dump = [r.model_dump() for r in rules]

    if req.tenant_id and req.project_id and req.persist:
        base = create_project_dir(req.tenant_id, req.project_id)
        write_artifact_text(base, "rules_text.md", req.rules_text or "")
        write_artifact_json(base, "rules_ir.json", rules_dump)
        write_artifact_json(base, "ontology.json", onto.model_dump())

    return ParseRulesResponse(rules=rules_dump, unresolved=unresolved, warnings=warnings)
@app.post("/compile", response_model=CompileResponse)
def compile_model(req: CompileRequest) -> CompileResponse:
    """Validate the inputs, build the solver model, and report its size.

    The model is built but not solved. The response carries the number of
    x (decision) variables, the number of auxiliary variables, and the build
    report. When ``tenant_id``, ``project_id`` and ``persist`` are all set,
    the ontology, rules, decision-variable list and report are written as
    project artifacts.

    Raises:
        HTTPException: status 400 if the ontology or any rule fails validation.
    """
    try:
        onto = Ontology.model_validate(req.ontology)
        rules = [RuleIR.model_validate(r) for r in req.rules or []]
    except Exception as e:
        # Chain the validation error so server-side tracebacks show the real cause.
        raise HTTPException(status_code=400, detail=f"Invalid input: {e}") from e

    ci = CompileInput(ontology=onto, rules=rules, cohorts=req.cohorts or {})
    # Only the variable registry and the report are needed here; the model and
    # the name->variable map are used by /solve.
    _model, dv, _var_map, report = build_model(ci)

    if req.tenant_id and req.project_id and req.persist:
        base = create_project_dir(req.tenant_id, req.project_id)
        write_artifact_json(base, "ontology.json", onto.model_dump())
        write_artifact_json(base, "rules_ir.json", [r.model_dump() for r in rules])
        write_artifact_json(base, "decision_vars.json", {"x_vars": dv.x_vars, "count": len(dv.x_vars)})
        write_artifact_json(base, "compile_report.json", report.model_dump())

    return CompileResponse(
        x_var_count=len(dv.x_vars),
        aux_var_count=len(dv.aux_vars),
        report=report.model_dump(),
    )
_logger = logging.getLogger(__name__)

# Column order of the schedule CSV; also the keys of each assignment dict.
_CSV_COLUMNS = ("resident", "block", "rotation")


def _parse_x_var_name(name: str) -> Optional[dict[str, str]]:
    """Split a decision-variable name ``x[R,B,K]`` into an assignment dict.

    Returns a dict keyed by ``resident``, ``block`` and ``rotation``. Returns
    ``None`` when *name* does not start with ``x[``, has no closing ``]``, or
    does not hold exactly three comma-separated fields.
    """
    if not name.startswith("x["):
        return None
    close = name.find("]", 2)
    if close == -1:
        # Reject outright: slicing up to index -1 would silently chop the
        # last character off the rotation id.
        return None
    fields = name[2:close].split(",")
    if len(fields) != len(_CSV_COLUMNS):
        return None
    return dict(zip(_CSV_COLUMNS, fields))


def _assignments_to_csv(assignments: List[dict[str, str]]) -> str:
    """Render assignments as CSV text with a ``resident,block,rotation`` header.

    The csv module quotes ids containing commas, quotes or line breaks, so
    such ids cannot shift columns. Rows are newline-separated, and the
    result has no trailing newline.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(_CSV_COLUMNS)
    writer.writerows([a[col] for col in _CSV_COLUMNS] for a in assignments)
    # The header row guarantees at least one terminator; drop the final one.
    return buf.getvalue()[:-1]


@app.post("/solve", response_model=SolveResponse)
def solve(req: SolveRequest) -> SolveResponse:
    """Build the model, solve it with OR-Tools CP-SAT, and return the schedule.

    Optional solver knobs (time limit, worker count, random seed) are applied
    only when set on the request. Every ``x[R,B,K]`` variable that is true in
    the solution becomes one resident/block/rotation assignment. The schedule
    is returned both as a list of dicts and as CSV text. When ``tenant_id``,
    ``project_id`` and ``persist`` are all set, it is also written to
    ``schedule.json`` and ``schedule.csv``.

    Raises:
        HTTPException: status 400 if the ontology or any rule fails validation.
        HTTPException: status 422 if the solver finds no OPTIMAL or FEASIBLE solution.
    """
    try:
        onto = Ontology.model_validate(req.ontology)
        rules = [RuleIR.model_validate(r) for r in req.rules or []]
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid input: {e}") from e

    ci = CompileInput(ontology=onto, rules=rules, cohorts=req.cohorts or {})
    model, _dv, var_map, _report = build_model(ci)

    solver = cp_model.CpSolver()
    if req.time_limit_seconds is not None:
        solver.parameters.max_time_in_seconds = float(req.time_limit_seconds)
    if req.num_workers is not None:
        solver.parameters.num_search_workers = int(req.num_workers)
    if req.random_seed is not None:
        solver.parameters.random_seed = int(req.random_seed)

    status = solver.Solve(model)
    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        raise HTTPException(status_code=422, detail="Model infeasible or no solution found within limits")

    assignments: List[dict[str, str]] = []
    for name, var in var_map.items():
        if not (name.startswith("x[") and solver.BooleanValue(var)):
            continue
        assignment = _parse_x_var_name(name)
        if assignment is None:
            # Best-effort: keep building the schedule, but make the drop visible
            # instead of swallowing it silently.
            _logger.warning("Skipping unparseable decision variable %r", name)
            continue
        assignments.append(assignment)

    schedule_csv = _assignments_to_csv(assignments)

    if req.tenant_id and req.project_id and req.persist:
        base = create_project_dir(req.tenant_id, req.project_id)
        write_artifact_json(base, "schedule.json", {"assignments": assignments})
        write_artifact_text(base, "schedule.csv", schedule_csv)

    return SolveResponse(assignments=assignments, schedule_csv=schedule_csv)