| from greedrl.feature import * |
| from greedrl.variable import * |
| from greedrl import Problem |
|
|
# Feature declarations for this example: three continuous features followed by
# two variable-backed features, in that order.
# NOTE(review): 'distance_this_to_task' and 'distance_task_to_end' have no
# matching entry in the `variables` list of this file, and `features` is not
# attached to the problems built by make_problem -- confirm this is intended.
features = (
    [continuous_feature(name)
     for name in ('task_demand', 'worker_weight_limit', 'distance_matrix')]
    + [variable_feature(name)
       for name in ('distance_this_to_task', 'distance_task_to_end')]
)
|
|
# Variable declarations handed to every Problem built by make_problem.
# The constructors come from the star-import of greedrl.variable; the
# per-entry notes below are inferred from names and arguments only --
# confirm against the greedrl.variable documentation.
variables = [
    # Read by Constraint.mask_task / Constraint.finished as self.task_demand.
    task_demand_now('task_demand'),
    # Same source feature, restricted with only_this=True; read by do_task.
    task_demand_now('task_demand_this', feature='task_demand', only_this=True),
    # Read by Constraint.mask_task as self.task_weight.
    feature_variable('task_weight'),
    # Read by Constraint.do_task as self.task_weight_this.
    task_variable('task_weight_this', feature='task_weight'),
    # Read by Constraint as self.worker_weight_limit.
    worker_variable('worker_weight_limit'),
    # Read by Constraint as self.worker_used_weight.
    worker_used_resource('worker_used_weight', task_require='task_weight'),
    # Read by both Objective hooks as self.distance_last_to_this.
    edge_variable('distance_last_to_this', feature='distance_matrix', last_to_this=True),
]
|
|
|
|
class Constraint:
    """Constraint hooks: amount served per visit, task mask, and stop test.

    Every method reads tensors from attributes on ``self`` whose names match
    entries of the module-level ``variables`` list. NOTE(review): nothing in
    this class assigns those attributes, so presumably the framework provides
    them when it invokes the hooks -- confirm.
    """

    def do_task(self):
        """Return how many demand units are served on the current visit.

        The result is the element-wise minimum of ``task_demand_this`` and the
        number of whole ``task_weight_this`` units that fit into the weight
        capacity the worker has left (floor division).
        """
        spare_weight = self.worker_weight_limit - self.worker_used_weight
        units_that_fit = spare_weight // self.task_weight_this
        return torch.min(self.task_demand_this, units_that_fit)

    def mask_task(self):
        """Return a boolean tensor flagging tasks by demand and weight.

        An entry is True when the task's demand is <= 0, or when its weight
        is greater than the worker's remaining weight capacity. Presumably
        True means "not selectable" -- confirm against the framework.
        """
        spare_weight = self.worker_weight_limit - self.worker_used_weight
        # [:, None] appends an axis so each row's remaining capacity is
        # broadcast across that row's tasks.
        too_heavy = self.task_weight > spare_weight[:, None]

        blocked = self.task_demand <= 0
        blocked |= too_heavy
        return blocked

    def finished(self):
        """Return, per row, whether every task's demand is <= 0."""
        no_demand_left = self.task_demand <= 0
        return no_demand_left.all(dim=1)
|
|
|
|
class Objective:
    """Objective hooks; both return the ``distance_last_to_this`` value.

    ``distance_last_to_this`` is declared in the module-level ``variables``
    list as an edge variable over ``distance_matrix``. NOTE(review): how the
    framework combines the returned values is not visible here -- confirm.
    """

    def step_worker_end(self):
        """Return ``distance_last_to_this`` for the worker-end step."""
        edge_cost = self.distance_last_to_this
        return edge_cost

    def step_task(self):
        """Return ``distance_last_to_this`` for a task step."""
        edge_cost = self.distance_last_to_this
        return edge_cost
|
|
|
|
def make_problem(batch_count, batch_size=1, task_count=100):
    """Build ``batch_count`` random problem instances.

    Args:
        batch_count: number of Problem objects to create; their ids run
            from 0 to ``batch_count - 1``.
        batch_size: must be 1; any other value trips the assertion.
        task_count: number of tasks per instance.

    Returns:
        A list of ``batch_count`` Problem objects, each carrying a single
        worker weight limit of 50, integer task demands drawn from
        ``torch.randint(1, 10, ...)``, unit task weights, an integer
        distance matrix over ``task_count + 1`` random 2-D points, and the
        module-level ``variables``, ``Constraint`` and ``Objective``.

    NOTE(review): the module-level ``features`` list is not attached to the
    problems built here -- confirm whether that is intentional.
    """
    assert batch_size == 1

    def build_instance(index):
        # Attribute assignment order matches the sequence of random draws:
        # demands first, then locations.
        instance = Problem()
        instance.id = index

        instance.worker_weight_limit = [50]

        instance.task_demand = torch.randint(1, 10, (task_count,), dtype=torch.int64)

        instance.task_weight = torch.ones(task_count, dtype=torch.int64)

        # One more point than there are tasks (presumably the extra one is
        # the depot -- confirm).
        points = torch.rand(task_count + 1, 2, dtype=torch.float32)
        pairwise_offsets = points[:, None, :] - points[None, :, :]
        scaled_distances = torch.norm(pairwise_offsets, dim=2) * 1000
        # Converted to int64 after scaling by 1000.
        instance.distance_matrix = scaled_distances.to(torch.int64)

        instance.variables = variables
        instance.constraint = Constraint
        instance.objective = Objective
        return instance

    return [build_instance(index) for index in range(batch_count)]
|
|
|
|
if __name__ == '__main__':
    # Script entry point: extend sys.path with the directory one level above
    # this file, then hand make_problem to runner.run.
    import sys
    import os.path as osp

    parent_dir = osp.join(osp.dirname(__file__), '../')
    # Presumably this is what makes the `runner` module importable -- confirm.
    sys.path.append(parent_dir)
    import runner

    runner.run(make_problem)
|
|