| import torch | |
| import functools | |
| from .utils import repeat | |
class VarMeta(object):
    """Deferred variable constructor.

    Stores a variable class plus fixed keyword arguments; calling the meta
    later instantiates the class for a concrete problem/batch.
    """

    def __init__(self, clazz, **kwargs):
        self.clazz = clazz
        self._kwargs = kwargs
        # Expose the fixed kwargs as attributes for convenient inspection.
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __call__(self, problem, batch_size, sample_num):
        """Build the wrapped variable, injecting the problem/batch context."""
        call_kwargs = dict(self._kwargs)
        call_kwargs.update(
            problem=problem.feats,
            batch_size=batch_size,
            sample_num=sample_num,
            worker_num=problem.worker_num,
            task_num=problem.task_num,
        )
        return self.clazz(**call_kwargs)
def attribute_variable(name, attribute=None):
    """Declare a variable bound to a raw problem attribute (see AttributeVariable)."""
    meta = VarMeta(AttributeVariable, name=name, attribute=attribute)
    return meta
class AttributeVariable:
    """Holds a problem attribute verbatim; no batching or sampling applied."""

    def __init__(self, name, attribute, problem, batch_size, sample_num, worker_num, task_num):
        # The lookup key defaults to the variable's own name.
        key = name if attribute is None else attribute
        self.name = name
        self.value = problem[key]
def feature_variable(name, feature=None):
    """Declare a sample-repeated feature variable (see FeatureVariable)."""
    meta = VarMeta(FeatureVariable, name=name, feature=feature)
    return meta
class FeatureVariable:
    """A per-problem feature tensor repeated across samples."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        key = name if feature is None else feature
        # Only the 'id' feature or worker-/task-level features are accepted.
        assert key == 'id' or key.startswith(("worker_", "task_"))
        self.name = name
        self.feature = problem[key]
        self.value = repeat(self.feature, sample_num)
def task_variable(name, feature=None):
    """Declare a task-step-tracked variable (see TaskVariable)."""
    meta = VarMeta(TaskVariable, name=name, feature=feature)
    return meta
class TaskVariable:
    """Per-batch value picked from a task-level feature at each task step."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        key = name if feature is None else feature
        assert key.startswith("task_")
        self.name = name
        self.feature = problem[key]
        # Resize the problem axis to the batch size and drop the task axis:
        # the value holds one entry (the chosen task's feature) per batch row.
        shape = list(self.feature.size())
        shape[0] = batch_size
        shape.pop(1)
        self.value = self.feature.new_zeros(shape)

    def step_task(self, b_index, p_index, t_index):
        # Select the chosen task's feature row for each active batch entry.
        self.value[b_index] = self.feature[p_index, t_index]
def worker_variable(name, feature=None):
    """Declare a worker-start-tracked variable (see WorkerVariable)."""
    meta = VarMeta(WorkerVariable, name=name, feature=feature)
    return meta
class WorkerVariable:
    """Per-batch value picked from a worker-level feature when a worker starts."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        key = name if feature is None else feature
        assert key.startswith("worker_")
        self.name = name
        self.feature = problem[key]
        # Resize the problem axis to the batch size and drop the worker axis:
        # the value holds one entry (the active worker's feature) per batch row.
        shape = list(self.feature.size())
        shape[0] = batch_size
        shape.pop(1)
        self.value = self.feature.new_zeros(shape)

    def step_worker_start(self, b_index, p_index, w_index):
        # Load the starting worker's feature for each active batch entry.
        self.value[b_index] = self.feature[p_index, w_index]
def worker_task_variable(name, feature=None):
    """Declare a worker-task cross variable (see WorkerTaskVariable)."""
    meta = VarMeta(WorkerTaskVariable, name=name, feature=feature)
    return meta
class WorkerTaskVariable:
    """Worker-task feature: caches the active worker's task row on worker
    start, then exposes the chosen task's entry at each task step.

    Bug fix: the original computed ``self.value``'s shape with ``del size[2]``.
    At that point the size list is ``[batch, task, ...]``; ``step_task`` writes
    ``_feature[b, t]`` — the task axis already consumed — so the task axis
    (index 1) must be removed. ``del size[2]`` raised IndexError for a 3-D
    feature (problem, worker, task) and kept the wrong axis for higher ranks.
    """

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.startswith("worker_task_")
        self.name = name
        self.feature = problem[feature]
        # (P, W, T, ...) -> (B, T, ...): per-batch cache of the active worker's row.
        size = list(self.feature.size())
        size[0] = batch_size
        del size[1]
        self._feature = self.feature.new_zeros(size)
        # (B, T, ...) -> (B, ...): per-batch value of the chosen task.
        del size[1]
        self.value = self.feature.new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        # Cache the starting worker's task row for each active batch entry.
        self._feature[b_index] = self.feature[p_index, w_index]

    def step_task(self, b_index, p_index, t_index):
        # Pick the chosen task's entry from the cached worker row.
        self.value[b_index] = self._feature[b_index, t_index]
def worker_task_group(name, feature=None):
    """Declare a per-group task counter variable (see WorkerTaskGroup)."""
    meta = VarMeta(WorkerTaskGroup, name=name, feature=feature)
    return meta
class WorkerTaskGroup:
    """Counts, per batch row, how many tasks of each group the current worker
    has taken; the counts reset whenever a new worker starts."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        key = name if feature is None else feature
        assert key.startswith("task_")
        self.name = name
        self.feature = problem[key].long()
        assert torch.all(self.feature >= 0)
        # Group ids are dense non-negative ints, so max()+1 is the group count.
        group_num = self.feature.max() + 1
        self.value = self.feature.new_zeros(batch_size, group_num)

    def step_worker_start(self, b_index, p_index, w_index):
        # A fresh worker starts with empty group counts.
        self.value[b_index] = 0

    def step_task(self, b_index, p_index, t_index):
        # Increment the counter of the chosen task's group.
        group = self.feature[p_index, t_index]
        self.value[b_index, group] += 1
def worker_task_item(name, item_id, item_num):
    """Declare a per-item accumulation variable (see WorkerTaskItem)."""
    meta = VarMeta(WorkerTaskItem, name=name, item_id=item_id, item_num=item_num)
    return meta
class WorkerTaskItem:
    # Accumulates, per batch row, how many units of each item the current
    # worker has collected; make_feat derives per-task still-unmet item counts.
    def __init__(self, name, item_id, item_num, problem, batch_size, sample_num, worker_num, task_num):
        # Both features are task-level: item_id names the items a task carries,
        # item_num their quantities.
        assert item_id.startswith('task_')
        assert item_num.startswith('task_')
        self.name = name
        # Repeat per-problem features across samples; ids must be non-negative
        # since they index a dense item table.
        self.item_id = repeat(problem[item_id], sample_num).long()
        self.item_num = repeat(problem[item_num], sample_num)
        assert torch.all(self.item_id >= 0)
        size = [0, 0]
        size[0] = self.item_id.size(0)
        # max()+1 == number of distinct items (assumes dense ids — TODO confirm).
        size[1] = self.item_id.max() + 1
        self.value = self.item_num.new_zeros(size)
    def step_worker_start(self, b_index, p_index, w_index):
        # A new worker starts with nothing collected.
        self.value[b_index] = 0
    def step_task(self, b_index, p_index, t_index):
        # Accumulate the chosen task's item quantities into the item table.
        item_id = self.item_id[b_index, t_index]
        item_num = self.item_num[b_index, t_index]
        self.value[b_index[:, None], item_id] += item_num
    def make_feat(self):
        # NOTE(review): gather(2, self.item_id) needs item_id to be 3-D
        # (batch, task, items-per-task) — confirm the feature's rank upstream.
        NT = self.item_id.size(1)
        v = self.value[:, None, :]
        v = v.expand(-1, NT, -1)
        v = v.gather(2, self.item_id).clamp(0, 1)
        # Per task: required items (as 0/1 flags) minus already-collected items
        # (as 0/1 flags); summing the clamped remainder counts unmet items.
        v = self.item_num.clamp(0, 1) - v
        return v.clamp(0, 1).sum(2)
def task_demand_now(name, feature=None, only_this=False):
    """Declare a remaining-demand variable (see TaskDemandNow)."""
    meta = VarMeta(TaskDemandNow, name=name, feature=feature, only_this=only_this)
    return meta
class TaskDemandNow:
    """Remaining integer demand per task, decremented as work is done.

    With ``only_this`` the public ``value`` holds only the demand of the task
    chosen at the latest step; otherwise it aliases the full demand tensor.
    """

    def __init__(self, name, feature, only_this, problem, batch_size, sample_num, worker_num, task_num):
        key = name if feature is None else feature
        assert key.startswith("task_")
        self.name = name
        self.only_this = only_this
        self._value = repeat(problem[key], sample_num)
        # Demand must be a non-negative integer tensor.
        int_types = (torch.int8, torch.int16, torch.int32, torch.int64)
        assert self._value.dtype in int_types
        assert torch.all(self._value >= 0)
        if only_this:
            # One scalar per batch row: the current task's remaining demand.
            self.value = self._value.new_zeros(self._value.size(0))
        else:
            self.value = self._value

    def step_task(self, b_index, p_index, t_index, done):
        if done is not None:
            # Subtract the amount of demand satisfied this step.
            self._value[b_index, t_index] -= done
        if self.only_this:
            self.value[b_index] = self._value[b_index, t_index]
        else:
            self.value = self._value
def worker_count_now(name, feature=None):
    """Declare a remaining-worker-count variable (see WorkerCountNow)."""
    meta = VarMeta(WorkerCountNow, name=name, feature=feature)
    return meta
class WorkerCountNow:
    """Remaining count per worker, decremented each time a worker starts."""

    def __init__(self, name, feature, problem, batch_size, sample_num, worker_num, task_num):
        key = name if feature is None else feature
        assert key.startswith("worker_")
        self.name = name
        self.value = repeat(problem[key], sample_num)
        # Counts must be a non-negative integer tensor.
        int_types = (torch.int8, torch.int16, torch.int32, torch.int64)
        assert self.value.dtype in int_types
        assert torch.all(self.value >= 0)

    def step_worker_start(self, b_index, p_index, w_index):
        # Consume one unit of the worker that just started.
        self.value[b_index, w_index] -= 1
def edge_variable(name, feature, last_to_this=False,
                  this_to_task=False, task_to_end=False, last_to_loop=False):
    """Declare an edge-matrix variable; exactly one mode flag must be set
    (see EdgeVariable)."""
    return VarMeta(
        EdgeVariable,
        name=name,
        feature=feature,
        last_to_this=last_to_this,
        this_to_task=this_to_task,
        task_to_end=task_to_end,
        last_to_loop=last_to_loop,
    )
class EdgeVariable:
    """Tracks edge-matrix values along the route being built.

    Exactly one mode is active:
      - last_to_this: edge from the previously visited node to the current one
      - this_to_task: edges from the current node to every task
      - task_to_end:  edges from every task to the active worker's end node
      - last_to_loop: edge from the last node back to the loop's first task

    Matrix node indexing (derived from the index arithmetic below): worker
    starts occupy [0, worker_num), worker ends [worker_num, 2*worker_num),
    tasks [2*worker_num, 2*worker_num + task_num).

    Bug fix: the ``feature is None -> feature = name`` fallback now runs
    BEFORE ``assert feature.endswith("_matrix")``. The original asserted
    first, so passing ``feature=None`` crashed with AttributeError instead of
    defaulting to the variable name (the convention every other variable
    class in this module follows).
    """

    def __init__(self, name, feature, last_to_this, this_to_task, task_to_end, last_to_loop,
                 problem, batch_size, sample_num, worker_num, task_num):
        if feature is None:
            feature = name
        assert feature.endswith("_matrix")
        # Exactly one mode flag must be set.
        flags = [last_to_this, this_to_task, task_to_end, last_to_loop]
        assert flags.count(True) == 1 and flags.count(False) == 3
        self.name = name
        self.last_to_this = last_to_this
        self.this_to_task = this_to_task
        self.task_to_end = task_to_end
        self.last_to_loop = last_to_loop
        self.worker_num = worker_num
        self.task_num = task_num
        self.feature = problem[feature]
        size = list(self.feature.size())
        size[0] = batch_size
        del size[1:3]  # drop the two node axes of the (P, N, N, ...) matrix
        if self.this_to_task or self.task_to_end:
            # One edge value per task.
            size.insert(1, task_num)
            self.value = self.feature.new_zeros(size)
        else:
            self.value = self.feature.new_zeros(size)
        # Per-batch node bookkeeping (int64 so they can be used as indices).
        self.end_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.loop_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        self.last_index = self.feature.new_zeros(size[0], dtype=torch.int64)
        # Matrix indices of all task nodes, shaped (1, task_num) for broadcast.
        self.task_index = (torch.arange(task_num) + worker_num * 2)[None, :]

    def step_worker_start(self, b_index, p_index, w_index):
        if self.last_to_this:
            # New worker: no previous node yet, so the edge value resets.
            self.value[b_index] = 0
            self.last_index[b_index] = w_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, w_index)
        elif self.task_to_end:
            # Remember this worker's end node, then refresh task->end edges.
            self.end_index[b_index] = w_index + self.worker_num
            self.do_task_to_end(b_index, p_index)
        elif self.last_to_loop:
            self.value[b_index] = 0
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        this_index = w_index + self.worker_num
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            self.do_last_to_loop(b_index, p_index)

    def step_task(self, b_index, p_index, t_index):
        this_index = t_index + self.worker_num * 2
        if self.last_to_this:
            self.do_last_to_this(b_index, p_index, this_index)
            self.last_index[b_index] = this_index
        elif self.this_to_task:
            self.do_this_to_task(b_index, p_index, this_index)
        elif self.task_to_end:
            pass
        elif self.last_to_loop:
            # The loop start is the first task visited after a worker node
            # (last_index < worker_num means we just left a worker start).
            last_index = self.last_index[b_index]
            loop_index = self.loop_index[b_index]
            self.loop_index[b_index] = torch.where(last_index < self.worker_num, this_index, loop_index)
            self.last_index[b_index] = this_index

    def do_last_to_this(self, b_index, p_index, this_index):
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, this_index]

    def do_this_to_task(self, b_index, p_index, this_index):
        # Broadcast: each batch row gets the edges from its node to all tasks.
        p_index2 = p_index[:, None]
        this_index2 = this_index[:, None]
        task_index2 = self.task_index
        self.value[b_index] = self.feature[p_index2, this_index2, task_index2]

    def do_task_to_end(self, b_index, p_index):
        # Broadcast: each batch row gets the edges from all tasks to its end node.
        p_index2 = p_index[:, None]
        task_index2 = self.task_index
        end_index = self.end_index[b_index]
        end_index2 = end_index[:, None]
        self.value[b_index] = self.feature[p_index2, task_index2, end_index2]

    def do_last_to_loop(self, b_index, p_index):
        loop_index = self.loop_index[b_index]
        last_index = self.last_index[b_index]
        self.value[b_index] = self.feature[p_index, last_index, loop_index]

    def make_feat(self):
        # Only the per-task modes produce a 2-D value usable as a feature.
        assert self.this_to_task or self.task_to_end, \
            "one of [this_to_task, task_to_end] must be true"
        return self.value.clone()
def worker_used_resource(name, edge_require=None, task_require=None, task_ready=None, worker_ready=None, task_due=None):
    """Declare a consumed-resource accumulator (see WorkerUsedResource)."""
    return VarMeta(
        WorkerUsedResource,
        name=name,
        edge_require=edge_require,
        task_require=task_require,
        task_ready=task_ready,
        worker_ready=worker_ready,
        task_due=task_due,
    )
class WorkerUsedResource:
    """Accumulates the resource a worker has consumed along its route.

    Sources (each optional, at least one of the first four required):
      - edge_require:  (P, N, N, ...) matrix of per-edge resource cost
      - task_require:  per-task resource cost, scaled by the amount ``done``
      - task_ready:    per-task ready value; the total is clamped up to it
      - worker_ready:  per-worker initial value applied at worker start
      - task_due:      per-task due value (stored only; unused in this class)

    Fixes vs. original: the ``task_ready`` assertion message wrongly said
    "task_service"; a stray trailing semicolon removed; ``tenors`` local
    renamed to ``tensors``.
    """

    def __init__(self, name, edge_require, task_require, task_ready, worker_ready, task_due,
                 problem, batch_size, sample_num, worker_num, task_num):
        assert edge_require is None or edge_require.endswith("_matrix"), \
            "unsupported edge: {}".format(edge_require)
        assert task_require is None or task_require.startswith("task_"), \
            "unsupported task_require: {}".format(task_require)
        assert task_ready is None or task_ready.startswith("task_"), \
            "unsupported task_ready: {}".format(task_ready)
        assert worker_ready is None or worker_ready.startswith("worker_") and not worker_ready.startswith(
            "worker_task_")
        assert task_due is None or task_due.startswith("task_"), \
            "unsupported task_due: {}".format(task_due)
        self.name = name
        self.worker_num = worker_num
        self.task_num = task_num
        if edge_require is None:
            self.edge_require = None
        else:
            self.edge_require = problem[edge_require]
            # Matrix index of the last visited node, per batch row.
            self.last_index = self.edge_require.new_zeros(batch_size, dtype=torch.int64)
        if task_require is None:
            self.task_require = None
        else:
            self.task_require = problem[task_require]
            # Sample-repeated copy used by make_feat.
            self.task_require2 = repeat(self.task_require, sample_num)
        if task_ready is None:
            self.task_ready = None
        else:
            self.task_ready = problem[task_ready]
        if worker_ready is None:
            self.worker_ready = None
        else:
            self.worker_ready = problem[worker_ready]
        if task_due is None:
            self.task_due = None
        else:
            self.task_due = problem[task_due]
        tensors = [self.edge_require, self.task_require, self.task_ready, self.worker_ready]
        tensors = list(filter(lambda x: x is not None, tensors))
        assert tensors, "at least one of edge_require, task_require, task_ready, worker_ready is required!"
        # Derive the value shape from the first available source tensor:
        # drop both node axes for a matrix, or the single task/worker axis.
        size = list(tensors[0].size())
        size[0] = batch_size
        if self.edge_require is None:
            del size[1]
        else:
            del size[1:3]
        self.value = tensors[0].new_zeros(size)

    def step_worker_start(self, b_index, p_index, w_index):
        if self.worker_ready is None:
            self.value[b_index] = 0
        else:
            # Start from the worker's own ready value.
            self.value[b_index] = self.worker_ready[p_index, w_index]
        if self.edge_require is not None:
            self.last_index[b_index] = w_index

    def step_worker_end(self, b_index, p_index, w_index):
        if self.edge_require is not None:
            # Pay the edge cost from the last node to this worker's end node.
            last_index = self.last_index[b_index]
            this_index = w_index + self.worker_num
            self.value[b_index] += self.edge_require[p_index, last_index, this_index]
            self.last_index[b_index] = this_index

    def step_task(self, b_index, p_index, t_index, done):
        if done is None:
            # Arrival phase: pay the edge cost, then clamp up to task_ready.
            if self.edge_require is not None:
                last_index = self.last_index[b_index]
                this_index = t_index + (self.worker_num * 2)
                self.value[b_index] += self.edge_require[p_index, last_index, this_index]
                self.last_index[b_index] = this_index
            if self.task_ready is not None:
                self.value[b_index] = torch.max(self.value[b_index], self.task_ready[p_index, t_index])
        else:
            # Service phase: pay task_require scaled by the amount done.
            if self.task_require is not None:
                if self.value.dim() == 2:
                    done = done[:, None]
                self.value[b_index] += self.task_require[p_index, t_index] * done

    def make_feat(self):
        assert self.value.dim() == 2, \
            "value's dim must be 2, actual: {}".format(self.value.dim())
        assert self.task_require is not None, "task_require is required"
        # Per task: count resource dimensions where used + required is positive
        # (each dimension clamped to [0, 1] before summing).
        v = self.value[:, None, :] + self.task_require2
        return v.clamp(0, 1).sum(2, dtype=v.dtype)
def worker_task_sequence(name):
    """Declare a variable that captures the final route (see WorkerTaskSequence)."""
    meta = VarMeta(WorkerTaskSequence, name=name)
    return meta
class WorkerTaskSequence:
    """Holds the finished worker-task sequence; empty until step_finish."""

    def __init__(self, name, problem, batch_size, sample_num, worker_num, task_num):
        self.name = name
        self.value = None  # populated once the rollout finishes

    def step_finish(self, worker_task_seq):
        # Capture the completed sequence as-is.
        self.value = worker_task_seq