| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import List, Optional |
|
|
| from fastapi import FastAPI, HTTPException |
| from ortools.sat.python import cp_model |
|
|
| from core.models import CompileInput, Ontology, RuleIR |
| from nlp.parser import parse_rules_with_grounding |
| from solver.builder import build_model |
| from .schemas import ( |
| CompileRequest, |
| CompileResponse, |
| ParseRulesRequest, |
| ParseRulesResponse, |
| ProjectCreateRequest, |
| ProjectCreateResponse, |
| SolveRequest, |
| SolveResponse, |
| ) |
| from .storage import create_project_dir, write_artifact_text, write_artifact_json |
|
|
|
|
# ASGI application object; the @app.post route handlers in this module register on it.
app = FastAPI(
    title="Scheduler Service",
    version="0.1.0",
)
|
|
|
|
@app.post("/projects", response_model=ProjectCreateResponse)
def create_project(req: ProjectCreateRequest) -> ProjectCreateResponse:
    """Provision the storage directory for a tenant's project.

    The ``project_id`` in the response is the name of the directory that
    ``create_project_dir`` returned, not the value echoed from the request.

    Args:
        req: Request carrying ``tenant_id`` and ``project_id``.

    Returns:
        The tenant id and the resulting project directory name.
    """
    tenant = req.tenant_id
    created_dir = create_project_dir(tenant, req.project_id)
    return ProjectCreateResponse(tenant_id=tenant, project_id=created_dir.name)
|
|
|
|
@app.post("/parse_rules", response_model=ParseRulesResponse)
def parse_rules(req: ParseRulesRequest) -> ParseRulesResponse:
    """Parse free-text rules into RuleIR objects grounded in an ontology.

    Args:
        req: Request carrying the raw ``ontology``, optional ``rules_text``,
            and optional ``tenant_id`` / ``project_id`` / ``persist`` fields
            controlling whether artifacts are written to project storage.

    Returns:
        The parsed rules (as plain dicts), plus the unresolved references and
        warnings reported by ``parse_rules_with_grounding``.

    Raises:
        HTTPException: 400 if the ontology fails validation.
    """
    try:
        onto = Ontology.model_validate(req.ontology)
    except Exception as e:
        # Chain the cause so server-side logs keep the original validation error.
        raise HTTPException(status_code=400, detail=f"Invalid ontology: {e}") from e

    rules: List[RuleIR] = []
    unresolved: List[str] = []
    warnings: List[str] = []
    # Missing/empty rules_text is not an error; it simply yields no rules.
    if req.rules_text:
        rules, unresolved, warnings = parse_rules_with_grounding(req.rules_text, onto)

    # Serialize once; the same dicts feed both the artifact and the response.
    rule_dicts = [r.model_dump() for r in rules]

    # Artifacts are written only when a project is identified and the caller opts in.
    if req.tenant_id and req.project_id and req.persist:
        base = create_project_dir(req.tenant_id, req.project_id)
        write_artifact_text(base, "rules_text.md", req.rules_text or "")
        write_artifact_json(base, "rules_ir.json", rule_dicts)
        write_artifact_json(base, "ontology.json", onto.model_dump())

    return ParseRulesResponse(rules=rule_dicts, unresolved=unresolved, warnings=warnings)
|
|
|
|
@app.post("/compile", response_model=CompileResponse)
def compile_model(req: CompileRequest) -> CompileResponse:
    """Validate the ontology and rules, build the CP-SAT model, and report on it.

    The model is built but not solved; the response carries the decision and
    auxiliary variable counts plus the compile report from ``build_model``.

    Args:
        req: Request carrying the raw ``ontology``, optional ``rules`` and
            ``cohorts``, and optional ``tenant_id`` / ``project_id`` /
            ``persist`` fields controlling artifact persistence.

    Returns:
        Variable counts and the serialized compile report.

    Raises:
        HTTPException: 400 if the ontology or any rule fails validation.
    """
    try:
        onto = Ontology.model_validate(req.ontology)
        rules = [RuleIR.model_validate(r) for r in req.rules or []]
    except Exception as e:
        # Chain the cause so server-side logs keep the original validation error.
        raise HTTPException(status_code=400, detail=f"Invalid input: {e}") from e

    ci = CompileInput(ontology=onto, rules=rules, cohorts=req.cohorts or {})
    # Only the decision-variable container and the report are used here.
    _model, dv, _var_map, report = build_model(ci)

    x_var_count = len(dv.x_vars)
    report_dict = report.model_dump()

    # Artifacts are written only when a project is identified and the caller opts in.
    if req.tenant_id and req.project_id and req.persist:
        base = create_project_dir(req.tenant_id, req.project_id)
        write_artifact_json(base, "ontology.json", onto.model_dump())
        write_artifact_json(base, "rules_ir.json", [r.model_dump() for r in rules])
        write_artifact_json(base, "decision_vars.json", {"x_vars": dv.x_vars, "count": x_var_count})
        write_artifact_json(base, "compile_report.json", report_dict)

    return CompileResponse(
        x_var_count=x_var_count,
        aux_var_count=len(dv.aux_vars),
        report=report_dict,
    )
|
|
|
|
def _configure_solver(req: SolveRequest) -> cp_model.CpSolver:
    """Create a CP-SAT solver with the request's optional limits and seed applied."""
    solver = cp_model.CpSolver()
    if req.time_limit_seconds is not None:
        solver.parameters.max_time_in_seconds = float(req.time_limit_seconds)
    if req.num_workers is not None:
        # NOTE(review): newer OR-Tools prefers `num_workers`; the older
        # `num_search_workers` name is kept for compatibility.
        solver.parameters.num_search_workers = int(req.num_workers)
    if req.random_seed is not None:
        solver.parameters.random_seed = int(req.random_seed)
    return solver


def _parse_x_var_name(name: str) -> Optional[tuple[str, str, str]]:
    """Split an ``x[resident,block,rotation]`` variable name into its three ids.

    Returns ``None`` when the name is not an ``x[...]`` variable, has no
    closing bracket, or does not contain exactly three comma-separated ids.
    Surrounding whitespace is stripped from each id.
    """
    if not name.startswith("x["):
        return None
    end = name.find("]")
    if end == -1:
        # Previously this silently sliced off the last character; reject instead.
        return None
    parts = [part.strip() for part in name[2:end].split(",")]
    if len(parts) != 3:
        return None
    resident_id, block_id, rotation_id = parts
    return resident_id, block_id, rotation_id


def _extract_assignments(solver: cp_model.CpSolver, var_map: dict) -> List[dict]:
    """Collect the assignments whose ``x[...]`` variable is true in the solution.

    Variables that are not ``x[...]`` or whose names cannot be parsed are
    skipped. Output order follows ``var_map`` iteration order.
    """
    assignments: List[dict] = []
    for name, var in var_map.items():
        ids = _parse_x_var_name(name)
        if ids is None or not solver.BooleanValue(var):
            continue
        resident_id, block_id, rotation_id = ids
        assignments.append(
            {
                "resident": resident_id,
                "block": block_id,
                "rotation": rotation_id,
            }
        )
    return assignments


def _assignments_to_csv(assignments: List[dict]) -> str:
    """Render assignments as CSV text with a header row and no trailing newline."""
    lines = ["resident,block,rotation"]
    lines.extend(f"{a['resident']},{a['block']},{a['rotation']}" for a in assignments)
    return "\n".join(lines)


@app.post("/solve", response_model=SolveResponse)
def solve(req: SolveRequest) -> SolveResponse:
    """Build the CP-SAT model from the ontology and rules, solve it, and return the schedule.

    Args:
        req: Request carrying the raw ``ontology``, optional ``rules`` and
            ``cohorts``, optional solver settings (``time_limit_seconds``,
            ``num_workers``, ``random_seed``), and optional ``tenant_id`` /
            ``project_id`` / ``persist`` fields controlling artifact persistence.

    Returns:
        The selected resident/block/rotation assignments and the same data as CSV.

    Raises:
        HTTPException: 400 if the ontology or any rule fails validation; 422
            if the solver status is neither OPTIMAL nor FEASIBLE. The detail
            includes the CP-SAT status name.
    """
    try:
        onto = Ontology.model_validate(req.ontology)
        rules = [RuleIR.model_validate(r) for r in req.rules or []]
    except Exception as e:
        # Chain the cause so server-side logs keep the original validation error.
        raise HTTPException(status_code=400, detail=f"Invalid input: {e}") from e

    ci = CompileInput(ontology=onto, rules=rules, cohorts=req.cohorts or {})
    model, _dv, var_map, _report = build_model(ci)

    solver = _configure_solver(req)
    status = solver.Solve(model)
    if status not in (cp_model.OPTIMAL, cp_model.FEASIBLE):
        # Include the status so INFEASIBLE / MODEL_INVALID / UNKNOWN can be told apart.
        raise HTTPException(
            status_code=422,
            detail=(
                "Model infeasible or no solution found within limits"
                f" (status: {solver.StatusName(status)})"
            ),
        )

    assignments = _extract_assignments(solver, var_map)
    schedule_csv = _assignments_to_csv(assignments)

    # Artifacts are written only when a project is identified and the caller opts in.
    if req.tenant_id and req.project_id and req.persist:
        base = create_project_dir(req.tenant_id, req.project_id)
        write_artifact_json(base, "schedule.json", {"assignments": assignments})
        write_artifact_text(base, "schedule.csv", schedule_csv)

    return SolveResponse(assignments=assignments, schedule_csv=schedule_csv)
|
|
|
|
|
|
|
|