# scheduler/solver/builder.py
# Owen Kosman
# improved the parser
# 6b671d9
from __future__ import annotations
from typing import Dict, List, Tuple, Iterable, Set
from ortools.sat.python import cp_model
from core.models import CompileInput, DecisionVars, CompileReport, RuleIR, Window, Rotation
# Common spellings of rotation names mapped to canonical ontology ids. Applied only
# when the name used by a rule is not itself a known rotation id.
_ROTATION_ALIASES: Dict[str, str] = {
    "Night": "NIGHTS",
    "Nights": "NIGHTS",
    "Electives": "ELECTIVES",
    "ELECTIVE": "ELECTIVES",
}


def _first_present(params, *keys: str):
    """Return the value of the first of ``keys`` whose value in ``params`` is not None.

    Unlike an ``a or b or c`` chain this keeps legitimate falsy values such as ``0``,
    so e.g. ``{"max": 0}`` is honoured instead of being read as "no maximum".
    Returns None when ``params`` is empty/None or none of the keys has a value.
    """
    if not params:
        return None
    for key in keys:
        value = params.get(key)
        if value is not None:
            return value
    return None


def _param_str(params, key: str) -> str | None:
    """Return ``params[key]`` converted to ``str``, or None when it is absent.

    Avoids ``str(None) == "None"``, which would otherwise pass truthiness checks.
    """
    value = _first_present(params, key)
    return None if value is None else str(value)


def _selector_ids(selectors, key: str) -> List[str]:
    """Return ``selectors[key]`` as a list of strings; [] when absent, None or empty."""
    if not selectors:
        return []
    return [str(x) for x in (selectors.get(key) or [])]


def build_model(ci: CompileInput) -> Tuple[cp_model.CpModel, DecisionVars, Dict[str, cp_model.IntVar], CompileReport]:
    """Compile ``ci`` into a CP-SAT model over boolean x[resident, block, rotation] vars.

    One boolean is created per (resident, block, rotation) triple, with an
    exactly-one constraint per (resident, block). Each RuleIR in ``ci.rules`` is
    then translated into hard constraints or weighted penalty terms. Compilation
    is best-effort: a rule whose parameters are missing/malformed, whose
    ``constraint_type`` is unknown, or which raises while compiling is recorded
    in ``report.unsupported_rules`` instead of aborting the build. Constraints
    a rule added before it failed are not rolled back.

    Returns:
        ``(model, dv, var_map, report)``. ``var_map`` maps variable names to
        CP-SAT variables. Soft-rule penalty variables are also registered there
        as ``penalty_<i>`` and listed in ``dv.aux_vars``. When any exist, the
        model minimizes their weighted sum.
    """
    model = cp_model.CpModel()
    residents = [r.id for r in ci.ontology.residents]
    blocks = ci.ontology.blocks
    rotations = [r.id for r in ci.ontology.rotations]
    rotation_ids = set(rotations)
    # Chronological position of each block: used to order window selections and
    # to detect gaps between selected blocks.
    block_index = {blk: i for i, blk in enumerate(blocks)}

    # Create feasibility variables x[r,b,k] for all triples; a real system would
    # filter by eligibility here to keep the model sparse.
    var_map: Dict[str, cp_model.IntVar] = {}
    x_vars: List[str] = []
    for r in residents:
        for b in blocks:
            choices = []
            for k in rotations:
                name = f"x[{r},{b},{k}]"
                v = model.NewBoolVar(name)
                var_map[name] = v
                x_vars.append(name)
                choices.append(v)
            # Every resident holds exactly one rotation in every block.
            model.Add(sum(choices) == 1)

    # Statistics reported in the CompileReport.
    coverage_constraints = 0
    total_constraints = 0
    lock_constraints = 0
    eligibility_forced_zero = 0
    consecutive_constraints = 0
    rest_constraints = 0
    pairing_constraints = 0  # no pairing rule type is compiled yet
    preference_penalties = 0
    unsupported_rules = []
    feasibility_warnings = []
    rule_to_constraints = {}
    # Soft penalty variables and their (parallel) objective weights.
    penalty_vars: List[cp_model.IntVar] = []
    penalty_weights = []

    def resolve_rotation(name: str | None) -> str | None:
        """Map a rule's rotation name to an ontology id, using aliases only for unknown names."""
        if name is None or name in rotation_ids:
            return name
        return _ROTATION_ALIASES.get(name, name)

    def blocks_from_windows(windows: List[Window]) -> List[str]:
        """Return the known blocks covered by ``windows``, in chronological order.

        Each window contributes its explicit ``blocks`` plus the inclusive range
        ``start_block``..``end_block`` (either endpoint order). No windows, or
        windows matching no known block, select every block.
        """
        if not windows:
            return list(blocks)
        selected: Set[str] = set()
        for w in windows:
            if w.blocks:
                selected.update(blk for blk in w.blocks if blk in block_index)
            if (
                w.start_block
                and w.end_block
                and w.start_block in block_index
                and w.end_block in block_index
            ):
                lo, hi = sorted((block_index[w.start_block], block_index[w.end_block]))
                selected.update(blocks[lo : hi + 1])
        if not selected:
            return list(blocks)
        return sorted(selected, key=block_index.__getitem__)

    def get_var(rid: str, blk: str, rot: str) -> cp_model.IntVar | None:
        """Return the x[rid, blk, rot] variable, or None if that triple does not exist."""
        return var_map.get(f"x[{rid},{blk},{rot}]")

    def existing(candidates) -> List[cp_model.IntVar]:
        """Drop the None entries produced by get_var for unknown triples."""
        return [v for v in candidates if v is not None]

    def target_residents(rule: RuleIR) -> List[str]:
        """Residents selected by ``rule``; an absent or empty selection means all residents."""
        return _selector_ids(rule.selectors, "residents") or list(residents)

    for rule in ci.rules:
        params = rule.params
        try:
            ctype = rule.constraint_type
            if ctype == "rotation_total":
                # params: { rotation: str, count: int }
                rotation = resolve_rotation(_param_str(params, "rotation"))
                raw_count = _first_present(params, "count")
                if not rotation or raw_count is None:
                    raise ValueError("rotation_total requires 'rotation' and 'count'")
                count = int(raw_count)
                tgt_blocks = blocks_from_windows(rule.windows)
                for rid in target_residents(rule):
                    terms = existing(get_var(rid, blk, rotation) for blk in tgt_blocks)
                    if terms:
                        model.Add(sum(terms) == count)
                        total_constraints += 1

            elif ctype == "coverage_min_max":
                # params: { rotation: str,
                #           min | min_count | min_required: int | None,
                #           max | max_count | max_required: int | None }
                rotation = resolve_rotation(_param_str(params, "rotation"))
                if not rotation:
                    raise ValueError("coverage_min_max requires 'rotation'")
                raw_min = _first_present(params, "min", "min_count", "min_required")
                raw_max = _first_present(params, "max", "max_count", "max_required")
                # Convert up front so a malformed bound fails before any constraint is added.
                min_req = None if raw_min is None else int(raw_min)
                max_req = None if raw_max is None else int(raw_max)
                for blk in blocks_from_windows(rule.windows):
                    terms = existing(get_var(rid, blk, rotation) for rid in residents)
                    if not terms:
                        continue
                    if min_req is not None:
                        model.Add(sum(terms) >= min_req)
                        coverage_constraints += 1
                    if max_req is not None:
                        model.Add(sum(terms) <= max_req)
                        coverage_constraints += 1

            elif ctype == "eligibility":
                # Forbid the selected rotation(s) for the selected residents over the windows.
                # selectors: { residents: [..], rotations: [..] }
                sel_rotations = [
                    resolve_rotation(rot) for rot in _selector_ids(rule.selectors, "rotations")
                ]
                if not sel_rotations:
                    # No rotations specified: nothing to forbid.
                    continue
                tgt_blocks = blocks_from_windows(rule.windows)
                for rid in target_residents(rule):
                    for blk in tgt_blocks:
                        for rot in sel_rotations:
                            v = get_var(rid, blk, rot)
                            if v is not None:
                                model.Add(v == 0)
                                eligibility_forced_zero += 1

            elif ctype == "lock_assignment":
                # params: { locks: [ { resident, block, rotation } ] }
                if not params:
                    continue
                for item in params.get("locks") or []:
                    rid = str(item.get("resident"))
                    blk = str(item.get("block"))
                    rot = resolve_rotation(str(item.get("rotation")))
                    v = get_var(rid, blk, rot)
                    if v is not None:
                        model.Add(v == 1)
                        lock_constraints += 1
                        rule_to_constraints.setdefault(rule.rule_id, []).append(
                            f"Lock {rid} to {rot} in {blk}"
                        )

            elif ctype == "run_consecutive":
                # params: { rotation: str, max_consecutive: int }
                rotation = resolve_rotation(_param_str(params, "rotation"))
                raw_max = _first_present(params, "max_consecutive")
                if not rotation or raw_max is None:
                    raise ValueError("run_consecutive requires 'rotation' and 'max_consecutive'")
                max_consec = int(raw_max)
                if max_consec < 0:
                    raise ValueError("max_consecutive must be non-negative")
                tgt_blocks = blocks_from_windows(rule.windows)
                # Sliding windows of max_consec + 1 blocks that are consecutive in the
                # schedule (windows spanning a gap in the selection are skipped); at
                # most max_consec blocks of each window may be on `rotation`.
                spans = [
                    (i, tgt_blocks[i : i + max_consec + 1])
                    for i in range(len(tgt_blocks) - max_consec)
                    if block_index[tgt_blocks[i + max_consec]] - block_index[tgt_blocks[i]] == max_consec
                ]
                emitted = 0
                for rid in target_residents(rule):
                    for i, window_blocks in spans:
                        terms = existing(get_var(rid, blk, rotation) for blk in window_blocks)
                        if not terms:
                            continue
                        if rule.hardness == "hard":
                            model.Add(sum(terms) <= max_consec)
                        else:
                            # Soft: penalize each block in excess of the limit.
                            violation_var = model.NewIntVar(
                                0, len(terms), f"consec_penalty_{rule.rule_id}_{rid}_{i}"
                            )
                            model.Add(violation_var >= sum(terms) - max_consec)
                            penalty_vars.append(violation_var)
                            penalty_weights.append(rule.weight)
                        consecutive_constraints += 1
                        emitted += 1
                if emitted:
                    rule_to_constraints.setdefault(rule.rule_id, []).append(
                        f"Max {max_consec} consecutive {rotation}"
                    )

            elif ctype == "rest_gap":
                # params: { from_session: str, to_session: str, min_hours: int }
                # Only Night(block i) followed by AM(block i+1) is supported so far.
                from_session = params.get("from_session", "Night") if params else "Night"
                to_session = params.get("to_session", "AM") if params else "AM"
                if from_session != "Night" or to_session != "AM":
                    raise ValueError(f"rest_gap {from_session} -> {to_session} is not supported")
                # Rotations containing each session (loop-invariant, so computed once).
                night_rots = [rot.id for rot in ci.ontology.rotations if "Night" in rot.sessions]
                am_rots = [rot.id for rot in ci.ontology.rotations if "AM" in rot.sessions]
                tgt_blocks = blocks_from_windows(rule.windows)
                # Only pairs of blocks that are adjacent in the schedule.
                pairs = [
                    (i, tgt_blocks[i], tgt_blocks[i + 1])
                    for i in range(len(tgt_blocks) - 1)
                    if block_index[tgt_blocks[i + 1]] - block_index[tgt_blocks[i]] == 1
                ]
                emitted = 0
                for rid in target_residents(rule):
                    for i, cur_blk, next_blk in pairs:
                        night_vars = existing(get_var(rid, cur_blk, rot) for rot in night_rots)
                        am_vars = existing(get_var(rid, next_blk, rot) for rot in am_rots)
                        if not (night_vars and am_vars):
                            continue
                        # Each sum is at most 1 (exactly-one rotation per block), so
                        # "<= 1" forbids a Night rotation immediately followed by an AM one.
                        if rule.hardness == "hard":
                            model.Add(sum(night_vars) + sum(am_vars) <= 1)
                        else:
                            violation_var = model.NewIntVar(0, 1, f"rest_penalty_{rule.rule_id}_{rid}_{i}")
                            model.Add(violation_var >= sum(night_vars) + sum(am_vars) - 1)
                            penalty_vars.append(violation_var)
                            penalty_weights.append(rule.weight)
                        rest_constraints += 1
                        emitted += 1
                if emitted:
                    rule_to_constraints.setdefault(rule.rule_id, []).append(
                        f"Rest gap: no {from_session} -> {to_session}"
                    )

            elif ctype == "preference_weight":
                # params: { preferences: [ { resident, rotation, weight } ] }
                if not params:
                    continue
                tgt_blocks = blocks_from_windows(rule.windows)
                emitted = 0
                for pref in params.get("preferences") or []:
                    raw_rid = pref.get("resident")
                    raw_rot = pref.get("rotation")
                    if raw_rid is None or raw_rot is None:
                        continue
                    rid = str(raw_rid)
                    rot = resolve_rotation(str(raw_rot))
                    if not rid or not rot:
                        continue
                    weight = int(pref.get("weight", 1))
                    terms = existing(get_var(rid, blk, rot) for blk in tgt_blocks)
                    if not terms:
                        continue
                    # pref_var counts the blocks the resident spends on `rot`. The objective
                    # minimizes weight * pref_var, so a negative weight rewards the assignment
                    # and a positive weight penalizes it.
                    pref_var = model.NewIntVar(0, len(terms), f"pref_{rule.rule_id}_{rid}_{rot}")
                    model.Add(pref_var == sum(terms))
                    penalty_vars.append(pref_var)
                    penalty_weights.append(weight)
                    preference_penalties += 1
                    emitted += 1
                if emitted:
                    rule_to_constraints.setdefault(rule.rule_id, []).append("Preference weights applied")

            else:
                raise ValueError(f"unsupported constraint_type {ctype!r}")
        except Exception:
            # Best-effort compilation: a malformed or unsupported rule is reported
            # rather than aborting the whole build.
            unsupported_rules.append(rule.rule_id)

    dv = DecisionVars(x_vars=x_vars)
    report = CompileReport(
        num_residents=len(residents),
        num_blocks=len(blocks),
        num_rotations=len(rotations),
        num_x_vars=len(x_vars),
        num_aux_vars=len(penalty_vars),
        coverage_constraints=coverage_constraints,
        total_constraints=total_constraints,
        lock_constraints=lock_constraints,
        eligibility_forced_zero=eligibility_forced_zero,
        consecutive_constraints=consecutive_constraints,
        rest_constraints=rest_constraints,
        pairing_constraints=pairing_constraints,
        preference_penalties=preference_penalties,
        unsupported_rules=unsupported_rules,
        feasibility_warnings=feasibility_warnings,
        rule_to_constraints=rule_to_constraints,
    )

    # Soft rules and preferences: minimize the weighted sum of penalty variables.
    if penalty_vars:
        model.Minimize(sum(var * weight for var, weight in zip(penalty_vars, penalty_weights)))

    # Expose penalty variables for solution extraction.
    aux_names = [f"penalty_{i}" for i in range(len(penalty_vars))]
    for aux_name, penalty_var in zip(aux_names, penalty_vars):
        var_map[aux_name] = penalty_var
    dv.aux_vars = aux_names
    return model, dv, var_map, report