File size: 7,783 Bytes
e510416
 
 
 
 
 
 
e7cf451
e510416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7cf451
e510416
 
e7cf451
e510416
 
 
 
 
 
 
 
 
 
 
e7cf451
e510416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7cf451
e510416
 
e7cf451
e510416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7cf451
e510416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7cf451
e510416
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from datetime import datetime, date, time, timedelta

from solverforge_legacy.solver.score import (
    constraint_provider,
    ConstraintFactory,
    Joiners,
    HardSoftDecimalScore,
    ConstraintCollectors,
)

from .domain import Employee, Shift


def get_minute_overlap(shift1: Shift, shift2: Shift) -> int:
    """Return the number of whole minutes during which two shifts overlap.

    The overlap runs from the later start to the earlier end. The value is
    only meaningful for shifts that actually overlap; for disjoint shifts the
    result is negative.

    Returns:
        The overlap in whole minutes, as an ``int`` (``total_seconds() // 60``
        alone would yield a float such as ``120.0``).
    """
    overlap = min(shift1.end, shift2.end) - max(shift1.start, shift2.start)
    return int(overlap.total_seconds() // 60)


def is_overlapping_with_date(shift: Shift, dt: date) -> bool:
    """Return True if ``shift`` touches the calendar day ``dt``.

    ``dt`` matches when it lies between the shift's start date and end date,
    inclusive. A shift spanning several days therefore matches every day it
    crosses, not only its first and last day. A shift ending exactly at the
    midnight that opens ``dt`` still matches, with zero minutes on ``dt``.
    """
    return shift.start.date() <= dt <= shift.end.date()


def overlapping_in_minutes(
    first_start_datetime: datetime,
    first_end_datetime: datetime,
    second_start_datetime: datetime,
    second_end_datetime: datetime,
) -> int:
    """Return how long two time intervals overlap, in minutes, clamped at zero.

    The overlap runs from the later of the two starts to the earlier of the
    two ends. Despite the ``int`` annotation, the result is fractional minutes
    (a float) when the intervals overlap and the integer ``0`` otherwise.
    """
    overlap_begins = max(first_start_datetime, second_start_datetime)
    overlap_ends = min(first_end_datetime, second_end_datetime)
    minutes = (overlap_ends - overlap_begins).total_seconds() / 60
    # Disjoint or merely touching intervals give a non-positive span.
    return minutes if minutes > 0 else 0


def get_shift_overlapping_duration_in_minutes(shift: Shift, dt: date) -> int:
    """Return how many whole minutes of ``shift`` fall on the calendar day ``dt``.

    The day is treated as the half-open interval ``[dt 00:00, next day 00:00)``.
    The result is 0 when the shift does not touch ``dt``.
    """
    start_date_time = datetime.combine(dt, datetime.min.time())
    # Use the next midnight as the end of the day. ``datetime.max.time()``
    # (23:59:59.999999) stops a microsecond short, and the int() truncation
    # below would then drop a whole minute for shifts running until midnight.
    end_date_time = start_date_time + timedelta(days=1)
    overlap = overlapping_in_minutes(
        start_date_time, end_date_time, shift.start, shift.end
    )
    return int(overlap)


@constraint_provider
def define_constraints(constraint_factory: ConstraintFactory):
    """Return every constraint the solver evaluates.

    Each entry below is a builder function applied to ``constraint_factory``.
    Hard constraints are listed first and soft constraints second, in the
    same order as the resulting list.
    """
    builders = (
        # Hard constraints
        required_skill,
        no_overlapping_shifts,
        at_least_10_hours_between_two_shifts,
        one_shift_per_day,
        unavailable_employee,
        # max_shifts_per_employee,  # Optional extension - disabled by default
        # Soft constraints
        undesired_day_for_employee,
        desired_day_for_employee,
        balance_employee_shift_assignments,
    )
    return [build(constraint_factory) for build in builders]


def required_skill(constraint_factory: ConstraintFactory):
    """Hard constraint: penalize shifts for which ``shift.has_required_skill()`` is false.

    Each offending shift costs one hard point.
    """
    unskilled_shifts = constraint_factory.for_each(Shift).filter(
        lambda shift: not shift.has_required_skill()
    )
    return unskilled_shifts.penalize(HardSoftDecimalScore.ONE_HARD).as_constraint(
        "Missing required skill"
    )


def no_overlapping_shifts(constraint_factory: ConstraintFactory):
    """Hard constraint: an employee cannot work two shifts at the same time.

    Every unordered pair of shifts whose employees share a name and whose
    [start, end) ranges overlap is penalized by the overlap length in
    minutes, as computed by ``get_minute_overlap``.
    """
    same_employee = Joiners.equal(lambda shift: shift.employee.name)
    time_overlap = Joiners.overlapping(
        lambda shift: shift.start, lambda shift: shift.end
    )
    overlapping_pairs = constraint_factory.for_each_unique_pair(
        Shift, same_employee, time_overlap
    )
    return overlapping_pairs.penalize(
        HardSoftDecimalScore.ONE_HARD, get_minute_overlap
    ).as_constraint("Overlapping shift")


def at_least_10_hours_between_two_shifts(constraint_factory: ConstraintFactory):
    """Hard constraint: an employee needs at least 10 hours between two shifts.

    Consider every ordered pair (first, second) of shifts for employees with
    the same name, where ``first`` ends no later than ``second`` starts. Any
    gap shorter than 10 hours is penalized by ``600 - gap_in_whole_minutes``,
    so the shorter the rest, the bigger the penalty.
    """
    consecutive_pairs = constraint_factory.for_each(Shift).join(
        Shift,
        Joiners.equal(lambda shift: shift.employee.name),
        Joiners.less_than_or_equal(lambda shift: shift.end, lambda shift: shift.start),
    )
    # Floor the gap to whole hours; "< 10" then means "strictly under 10 hours".
    short_rest_pairs = consecutive_pairs.filter(
        lambda first_shift, second_shift: (
            second_shift.start - first_shift.end
        ).total_seconds() // 3600 < 10
    )
    return short_rest_pairs.penalize(
        HardSoftDecimalScore.ONE_HARD,
        lambda first_shift, second_shift: 600
        - (second_shift.start - first_shift.end).total_seconds() // 60,
    ).as_constraint("At least 10 hours between 2 shifts")


def one_shift_per_day(constraint_factory: ConstraintFactory):
    """Hard constraint: at most one shift per employee per start date.

    Each unordered pair of shifts whose employees share a name and which
    start on the same calendar date costs one hard point.
    """
    same_day_pairs = constraint_factory.for_each_unique_pair(
        Shift,
        Joiners.equal(lambda shift: shift.employee.name),
        Joiners.equal(lambda shift: shift.start.date()),
    )
    return same_day_pairs.penalize(HardSoftDecimalScore.ONE_HARD).as_constraint(
        "Max one shift per day"
    )


def unavailable_employee(constraint_factory: ConstraintFactory):
    """Hard constraint: employees should not work on their unavailable dates.

    Each shift is joined to its assigned employee, and one match is produced
    for every date in ``employee.unavailable_dates`` that the shift touches,
    as judged by ``is_overlapping_with_date``. The penalty is the number of
    whole minutes of the shift that fall on that calendar day.
    """
    return (
        constraint_factory.for_each(Shift)
        .join(
            Employee,
            Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
        )
        .flatten_last(lambda employee: employee.unavailable_dates)
        .filter(
            lambda shift, unavailable_date: is_overlapping_with_date(
                shift, unavailable_date
            )
        )
        .penalize(
            HardSoftDecimalScore.ONE_HARD,
            # The day is [00:00, next midnight). The next midnight is used
            # instead of 23:59:59 so that a shift running until midnight is
            # credited its final minute after the int() truncation.
            lambda shift, unavailable_date: int(
                (
                    min(
                        shift.end,
                        datetime.combine(unavailable_date, time(0, 0, 0))
                        + timedelta(days=1),
                    )
                    - max(
                        shift.start,
                        datetime.combine(unavailable_date, time(0, 0, 0)),
                    )
                ).total_seconds()
                / 60
            ),
        )
        .as_constraint("Unavailable employee")
    )


def max_shifts_per_employee(constraint_factory: ConstraintFactory):
    """
    Hard constraint: cap every employee at 12 assigned shifts.

    Each shift beyond the 12th costs one hard point.

    Why 12: in the demo data's SMALL dataset, 139 shifts / 15 employees is
    about 9.3 shifts per employee on average. A cap of 12 leaves headroom
    while still preventing extreme imbalance.

    A cap that is too low (e.g. 5) would make the problem infeasible. Always
    keep such a limit compatible with the dimensions of your data.
    """
    shifts_per_employee = constraint_factory.for_each(Shift).group_by(
        lambda shift: shift.employee, ConstraintCollectors.count()
    )
    overloaded = shifts_per_employee.filter(
        lambda employee, shift_count: shift_count > 12
    )
    return overloaded.penalize(
        HardSoftDecimalScore.ONE_HARD,
        lambda employee, shift_count: shift_count - 12,
    ).as_constraint("Max 12 shifts per employee")


def undesired_day_for_employee(constraint_factory: ConstraintFactory):
    """Soft constraint: avoid scheduling employees on their undesired dates.

    Each shift is joined to its assigned employee, and one match is produced
    for every date in ``employee.undesired_dates`` that the shift touches,
    according to ``shift.is_overlapping_with_date``. The soft penalty is the
    number of whole minutes of the shift that fall on that calendar day.
    """
    return (
        constraint_factory.for_each(Shift)
        .join(
            Employee,
            Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
        )
        .flatten_last(lambda employee: employee.undesired_dates)
        .filter(
            lambda shift, undesired_date: shift.is_overlapping_with_date(
                undesired_date
            )
        )
        .penalize(
            HardSoftDecimalScore.ONE_SOFT,
            # The day is [00:00, next midnight). The next midnight is used
            # instead of 23:59:59 so that a shift running until midnight is
            # credited its final minute after the int() truncation.
            lambda shift, undesired_date: int(
                (
                    min(
                        shift.end,
                        datetime.combine(undesired_date, time(0, 0, 0))
                        + timedelta(days=1),
                    )
                    - max(
                        shift.start,
                        datetime.combine(undesired_date, time(0, 0, 0)),
                    )
                ).total_seconds()
                / 60
            ),
        )
        .as_constraint("Undesired day for employee")
    )


def desired_day_for_employee(constraint_factory: ConstraintFactory):
    """Soft constraint: prefer scheduling employees on their desired dates.

    Each shift is joined to its assigned employee, and one match is produced
    for every date in ``employee.desired_dates`` that the shift touches,
    according to ``shift.is_overlapping_with_date``. The soft reward is the
    number of whole minutes of the shift that fall on that calendar day.
    """
    return (
        constraint_factory.for_each(Shift)
        .join(
            Employee,
            Joiners.equal(lambda shift: shift.employee, lambda employee: employee),
        )
        .flatten_last(lambda employee: employee.desired_dates)
        .filter(
            lambda shift, desired_date: shift.is_overlapping_with_date(desired_date)
        )
        .reward(
            HardSoftDecimalScore.ONE_SOFT,
            # The day is [00:00, next midnight). The next midnight is used
            # instead of 23:59:59 so that a shift running until midnight is
            # credited its final minute after the int() truncation.
            lambda shift, desired_date: int(
                (
                    min(
                        shift.end,
                        datetime.combine(desired_date, time(0, 0, 0))
                        + timedelta(days=1),
                    )
                    - max(
                        shift.start,
                        datetime.combine(desired_date, time(0, 0, 0)),
                    )
                ).total_seconds()
                / 60
            ),
        )
        .as_constraint("Desired day for employee")
    )


def balance_employee_shift_assignments(constraint_factory: ConstraintFactory):
    """Soft constraint: penalize an uneven distribution of shifts across employees.

    Shifts are counted per employee. ``complement`` adds employees with no
    shifts, each with a count of 0, so they take part in the balance. The
    ``load_balance`` collector's ``unfairness()`` value is then penalized as a
    decimal soft score.
    """
    shift_counts = constraint_factory.for_each(Shift).group_by(
        lambda shift: shift.employee, ConstraintCollectors.count()
    )
    # Include all employees which are not assigned to any shift.
    every_employee_count = shift_counts.complement(Employee, lambda e: 0)
    balance = every_employee_count.group_by(
        ConstraintCollectors.load_balance(
            lambda employee, shift_count: employee,
            lambda employee, shift_count: shift_count,
        )
    )
    return balance.penalize_decimal(
        HardSoftDecimalScore.ONE_SOFT,
        lambda load_balance: load_balance.unfairness(),
    ).as_constraint("Balance employee shift assignments")