blackopsrepl's picture
update
e7cf451
from solverforge_legacy.solver.score import (
constraint_provider,
ConstraintFactory,
Joiners,
HardSoftDecimalScore,
ConstraintCollectors,
)
from datetime import datetime, date, time
from .domain import Employee, Shift
def get_minute_overlap(shift1: Shift, shift2: Shift) -> int:
    """Return the whole minutes shared by two shifts.

    Args:
        shift1: First shift; its ``start`` and ``end`` are read.
        shift2: Second shift; its ``start`` and ``end`` are read.

    Returns:
        The span from the later start to the earlier end, floored to whole
        minutes, as an ``int``. The value is not clamped: it is negative
        when the two shifts do not overlap.
    """
    shared_span = min(shift1.end, shift2.end) - max(shift1.start, shift2.start)
    # ``total_seconds()`` is a float, so the floor division is a float too;
    # convert it so the result really is the ``int`` the signature promises.
    return int(shared_span.total_seconds() // 60)
def is_overlapping_with_date(shift: Shift, dt: date) -> bool:
    """Return whether ``dt`` is one of the calendar days touched by ``shift``.

    Only calendar dates are compared; times of day are ignored, so a shift
    that ends exactly at midnight still counts as touching that date.

    Args:
        shift: Shift whose ``start`` and ``end`` datetimes are inspected.
        dt: Calendar date to test.

    Returns:
        ``True`` when ``dt`` equals the date of the shift's start or end, or
        lies strictly between them; ``False`` otherwise.
    """
    first_day = shift.start.date()
    last_day = shift.end.date()
    if dt == first_day or dt == last_day:
        return True
    # A shift spanning more than two calendar days also covers every day in
    # between, which a plain start/end equality check would miss.
    return first_day < dt < last_day
def overlapping_in_minutes(
    first_start_datetime: datetime,
    first_end_datetime: datetime,
    second_start_datetime: datetime,
    second_end_datetime: datetime,
) -> int:
    """Return how many minutes two datetime ranges have in common.

    Args:
        first_start_datetime: Start of the first range.
        first_end_datetime: End of the first range.
        second_start_datetime: Start of the second range.
        second_end_datetime: End of the second range.

    Returns:
        The length, in minutes, of the span from the later start to the
        earlier end. A positive result is a float and keeps fractional
        minutes (true division is used); ``0`` is returned when the ranges
        only touch or do not overlap at all.
    """
    window_open = max(first_start_datetime, second_start_datetime)
    window_close = min(first_end_datetime, second_end_datetime)
    minutes = (window_close - window_open).total_seconds() / 60
    if minutes > 0:
        return minutes
    # Touching or disjoint ranges share no time.
    return 0
def get_shift_overlapping_duration_in_minutes(shift: Shift, dt: date) -> int:
    """Return the whole minutes of ``shift`` that fall on the date ``dt``.

    The day is modelled as running from 00:00:00 to 23:59:59.999999, and the
    overlap with ``shift.start``..``shift.end`` is delegated to
    ``overlapping_in_minutes``.

    Args:
        shift: Shift whose ``start`` and ``end`` datetimes are read.
        dt: Calendar date to intersect the shift with.

    Returns:
        The overlap truncated to an ``int`` number of minutes (``0`` when
        the shift does not fall on ``dt``).
    """
    day_opens = datetime.combine(dt, time.min)
    day_closes = datetime.combine(dt, time.max)
    return int(
        overlapping_in_minutes(day_opens, day_closes, shift.start, shift.end)
    )
@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
    """Assemble every constraint used by the solver.

    Args:
        constraint_factory: Factory handed to each constraint builder.

    Returns:
        A list with the hard constraints first, followed by the soft ones.
    """
    hard_constraints = [
        required_skill(constraint_factory),
        no_overlapping_shifts(constraint_factory),
        at_least_10_hours_between_two_shifts(constraint_factory),
        one_shift_per_day(constraint_factory),
        unavailable_employee(constraint_factory),
        # max_shifts_per_employee(constraint_factory),  # Optional extension - disabled by default
    ]
    soft_constraints = [
        undesired_day_for_employee(constraint_factory),
        desired_day_for_employee(constraint_factory),
        balance_employee_shift_assignments(constraint_factory),
    ]
    return hard_constraints + soft_constraints
def required_skill(constraint_factory: ConstraintFactory):
    """Hard constraint: a shift must report that its required skill is met.

    Every shift for which ``has_required_skill()`` is false costs one hard
    point.
    """

    def lacks_skill(shift):
        return not shift.has_required_skill()

    unskilled_shifts = constraint_factory.for_each(Shift).filter(lacks_skill)
    penalized = unskilled_shifts.penalize(HardSoftDecimalScore.ONE_HARD)
    return penalized.as_constraint("Missing required skill")
def no_overlapping_shifts(constraint_factory: ConstraintFactory):
    """Hard constraint: one employee cannot work two shifts at the same time.

    Unique pairs of shifts are matched on the employee's name and on
    overlapping ``start``/``end`` ranges; each pair is penalized by the
    value of ``get_minute_overlap`` for that pair.
    """
    same_employee = Joiners.equal(lambda shift: shift.employee.name)
    overlapping_times = Joiners.overlapping(
        lambda shift: shift.start,
        lambda shift: shift.end,
    )
    clashing_pairs = constraint_factory.for_each_unique_pair(
        Shift, same_employee, overlapping_times
    )
    penalized = clashing_pairs.penalize(
        HardSoftDecimalScore.ONE_HARD, get_minute_overlap
    )
    return penalized.as_constraint("Overlapping shift")
def at_least_10_hours_between_two_shifts(constraint_factory: ConstraintFactory):
    """Hard constraint: an employee needs 10 hours of rest between shifts.

    Each shift is joined with every shift of the same employee (matched by
    name) that starts at or after its end. Pairs whose gap is under 10 hours
    are penalized by the number of minutes missing from the 600-minute rest.
    """

    def rest_is_too_short(first_shift, second_shift):
        gap = second_shift.start - first_shift.end
        return gap.total_seconds() // (60 * 60) < 10

    def missing_rest_minutes(first_shift, second_shift):
        gap = second_shift.start - first_shift.end
        return 600 - (gap.total_seconds() // 60)

    same_employee = Joiners.equal(lambda shift: shift.employee.name)
    # The left shift must end no later than the right shift starts.
    ends_before_next_starts = Joiners.less_than_or_equal(
        lambda shift: shift.end,
        lambda shift: shift.start,
    )
    consecutive_pairs = constraint_factory.for_each(Shift).join(
        Shift, same_employee, ends_before_next_starts
    )
    penalized = consecutive_pairs.filter(rest_is_too_short).penalize(
        HardSoftDecimalScore.ONE_HARD, missing_rest_minutes
    )
    return penalized.as_constraint("At least 10 hours between 2 shifts")
def one_shift_per_day(constraint_factory: ConstraintFactory):
    """Hard constraint: an employee may start at most one shift per day.

    Every unique pair of shifts with the same employee name and the same
    start date costs one hard point.
    """
    same_employee = Joiners.equal(lambda shift: shift.employee.name)
    same_start_day = Joiners.equal(lambda shift: shift.start.date())
    same_day_pairs = constraint_factory.for_each_unique_pair(
        Shift, same_employee, same_start_day
    )
    penalized = same_day_pairs.penalize(HardSoftDecimalScore.ONE_HARD)
    return penalized.as_constraint("Max one shift per day")
def unavailable_employee(constraint_factory: ConstraintFactory):
    """Hard constraint: do not schedule an employee on an unavailable date.

    Each shift is paired with every entry of its employee's
    ``unavailable_dates``. Pairs accepted by ``is_overlapping_with_date``
    are penalized by the whole minutes of the shift that fall between
    00:00:00 and 23:59:59 of that date.
    """

    def minutes_on_date(shift, unavailable_date):
        day_start = datetime.combine(unavailable_date, time(0, 0, 0))
        day_end = datetime.combine(unavailable_date, time(23, 59, 59))
        overlap = min(shift.end, day_end) - max(shift.start, day_start)
        return int(overlap.total_seconds() / 60)

    shift_with_employee = constraint_factory.for_each(Shift).join(
        Employee,
        Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
    )
    # Replace the employee with one tuple per unavailable date.
    shift_with_date = shift_with_employee.flatten_last(
        lambda employee: employee.unavailable_dates
    )
    penalized = shift_with_date.filter(is_overlapping_with_date).penalize(
        HardSoftDecimalScore.ONE_HARD, minutes_on_date
    )
    return penalized.as_constraint("Unavailable employee")
def max_shifts_per_employee(constraint_factory: ConstraintFactory):
    """Hard constraint: no employee can have more than 12 shifts.

    Shifts are counted per employee, and each shift above 12 costs one hard
    point.

    The limit of 12 is based on the demo data dimensions:
    - SMALL dataset: 139 shifts / 15 employees = ~9.3 shifts on average
    - 12 leaves headroom while still preventing extreme imbalance

    Note: a limit that is too low (e.g. 5) would make the problem
    infeasible. Always make sure constraints are compatible with the data
    dimensions.
    """
    shifts_per_employee = constraint_factory.for_each(Shift).group_by(
        lambda shift: shift.employee,
        ConstraintCollectors.count(),
    )
    overloaded = shifts_per_employee.filter(
        lambda employee, shift_count: shift_count > 12
    )
    penalized = overloaded.penalize(
        HardSoftDecimalScore.ONE_HARD,
        lambda employee, shift_count: shift_count - 12,
    )
    return penalized.as_constraint("Max 12 shifts per employee")
def undesired_day_for_employee(constraint_factory: ConstraintFactory):
    """Soft constraint: avoid scheduling an employee on an undesired date.

    Each shift is paired with every entry of its employee's
    ``undesired_dates``. Pairs accepted by
    ``shift.is_overlapping_with_date`` are penalized by the whole minutes of
    the shift that fall between 00:00:00 and 23:59:59 of that date.
    """

    def touches_date(shift, undesired_date):
        # NOTE(review): unavailable_employee filters with the module-level
        # is_overlapping_with_date(); confirm the Shift method agrees with it.
        return shift.is_overlapping_with_date(undesired_date)

    def minutes_on_date(shift, undesired_date):
        day_start = datetime.combine(undesired_date, time(0, 0, 0))
        day_end = datetime.combine(undesired_date, time(23, 59, 59))
        overlap = min(shift.end, day_end) - max(shift.start, day_start)
        return int(overlap.total_seconds() / 60)

    shift_with_employee = constraint_factory.for_each(Shift).join(
        Employee,
        Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
    )
    # Replace the employee with one tuple per undesired date.
    shift_with_date = shift_with_employee.flatten_last(
        lambda employee: employee.undesired_dates
    )
    penalized = shift_with_date.filter(touches_date).penalize(
        HardSoftDecimalScore.ONE_SOFT, minutes_on_date
    )
    return penalized.as_constraint("Undesired day for employee")
def desired_day_for_employee(constraint_factory: ConstraintFactory):
    """Soft constraint: prefer scheduling an employee on a desired date.

    Each shift is paired with every entry of its employee's
    ``desired_dates``. Pairs accepted by ``shift.is_overlapping_with_date``
    are rewarded by the whole minutes of the shift that fall between
    00:00:00 and 23:59:59 of that date.
    """

    def touches_date(shift, desired_date):
        # NOTE(review): unavailable_employee filters with the module-level
        # is_overlapping_with_date(); confirm the Shift method agrees with it.
        return shift.is_overlapping_with_date(desired_date)

    def minutes_on_date(shift, desired_date):
        day_start = datetime.combine(desired_date, time(0, 0, 0))
        day_end = datetime.combine(desired_date, time(23, 59, 59))
        overlap = min(shift.end, day_end) - max(shift.start, day_start)
        return int(overlap.total_seconds() / 60)

    shift_with_employee = constraint_factory.for_each(Shift).join(
        Employee,
        Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
    )
    # Replace the employee with one tuple per desired date.
    shift_with_date = shift_with_employee.flatten_last(
        lambda employee: employee.desired_dates
    )
    rewarded = shift_with_date.filter(touches_date).reward(
        HardSoftDecimalScore.ONE_SOFT, minutes_on_date
    )
    return rewarded.as_constraint("Desired day for employee")
def balance_employee_shift_assignments(constraint_factory: ConstraintFactory):
    """Soft constraint: spread shifts evenly across employees.

    Shifts are counted per employee, employees without any shift are added
    with a count of 0, and the counts are fed into a load-balance collector.
    The penalty is the ``unfairness()`` value of that collector's result.
    """
    shifts_per_employee = constraint_factory.for_each(Shift).group_by(
        lambda shift: shift.employee,
        ConstraintCollectors.count(),
    )
    # Include all employees which are not assigned to any shift.
    every_employee = shifts_per_employee.complement(Employee, lambda e: 0)
    balance = every_employee.group_by(
        ConstraintCollectors.load_balance(
            lambda employee, shift_count: employee,
            lambda employee, shift_count: shift_count,
        )
    )
    penalized = balance.penalize_decimal(
        HardSoftDecimalScore.ONE_SOFT,
        lambda load_balance: load_balance.unfairness(),
    )
    return penalized.as_constraint("Balance employee shift assignments")