| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
"""
Module for computational graph execution.

Classes:
    Task: Abstract base class representing a computational task.
    Executor: Class for scheduling and executing directed acyclic task graphs.
"""
| |
|
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import networkx
import torch
import tqdm
from pydantic import BaseModel
from typing_extensions import Generic, TypeVar
| |
|
# Type variable for the value a Task produces when executed.
ValueT = TypeVar("ValueT")
| |
|
| |
|
class Task(ABC, BaseModel, Generic[ValueT], frozen=True):
    """
    Abstract base class representing a task in a computational graph.

    This class should be extended to define specific tasks. Each task can have
    arguments (dependencies) and a defined execution strategy.

    The model is declared with ``frozen=True``; Executor uses task instances
    as dictionary keys and set members.

    Attributes:
        Generic[ValueT] (TypeVar): The type of the value that the task returns upon execution.

    Methods:
        arguments: Abstract method to define task arguments (dependencies).
        execute: Abstract method to execute the task.
        priority: Returns the priority of the task for scheduling purposes.
        group_label: Returns an optional label for task grouping.
        uses_accelerator: Returns whether the task benefits from an accelerator.
    """

    @abstractmethod
    def arguments(self) -> Dict[str, "Task"]:
        """
        Returns a dictionary of arguments required for this task. The keys of the dictionary
        are argument names, and the values are Task instances. These keys correspond to the
        keyword argument names expected by the execute method.

        For example, if this method returns {'input1': taskA, 'input2': taskB}, the execute
        method should expect to be called as execute(input1=valueA, input2=valueB), where
        valueA and valueB are the outputs of taskA and taskB respectively.

        Returns:
            Dict[str, "Task"]: A dictionary mapping argument names to Task instances.
        """
        ...

    @abstractmethod
    def execute(self, **kwargs) -> ValueT:
        """
        Executes the task using the results of its dependencies.

        The keyword arguments (**kwargs) for this method are dynamically determined based on
        the dictionary returned by the 'arguments' method. Each key in the 'arguments' method's
        return dictionary becomes a keyword argument in this method, with its value being
        the result of the corresponding task's execution.

        Returns:
            ValueT: The result of the task execution.
        """
        ...

    def priority(self) -> int:
        """
        Returns the priority of the task for scheduling.

        Higher numbers indicate higher priority. Default is 0.

        Returns:
            int: The priority of the task.
        """
        return 0

    def group_label(self) -> Optional[str]:
        """
        Returns an optional label used for grouping tasks together.

        Returns:
            Optional[str]: The group label of the task, if any. Default is None.
        """
        return None

    def uses_accelerator(self) -> bool:
        """
        Returns True if the task can take advantage of matrix operation
        acceleration (such as on a GPU).

        Returns:
            bool: Whether the task benefits from an accelerator. Default is False.
        """
        return False
| |
|
| |
|
class Executor:
    """
    Schedules and executes a set of tasks and their dependencies.

    Handles scheduling, execution, the movement of data between devices, and
    the lifecycle of intermediate results.

    Attributes:
        math_device (torch.device): Device used for tensor computations.
        storage_device (torch.device): Device used for storing intermediate results.
        targets (List[Task]): List of target tasks to be executed.
        schedule (List[Task]): Calculated execution schedule of tasks.
        dependencies (Dict[Task, Set[Task]]): Dependencies of each task.
    """

    math_device: torch.device
    storage_device: torch.device
    targets: List["Task"]
    schedule: List["Task"]
    dependencies: Dict["Task", Set["Task"]]

    # Sentinel graph node used only while building the schedule; it is
    # filtered out of the resulting task list.
    DUMMY_TASK_VALUE = "!!DUMMY!!"

    def __init__(
        self,
        tasks: List["Task"],
        math_device: torch.device = torch.device("cpu"),
        storage_device: torch.device = torch.device("cpu"),
    ):
        """
        Initializes the Executor with a list of tasks and device configurations.

        The execution schedule (and the dependency map) is computed immediately.

        Args:
            tasks (List[Task]): The list of tasks to be executed.
            math_device (torch.device, optional): The device for tensor computations. Defaults to CPU.
            storage_device (torch.device, optional): The device for storing results. Defaults to CPU.
        """
        self.math_device = math_device
        self.storage_device = storage_device
        self.schedule = self._make_schedule(tasks)
        self.targets = tasks

    @staticmethod
    def _move_tensors(value: Any, device: torch.device) -> Any:
        """
        Return ``value`` with its tensors placed on ``device``.

        Handles a bare tensor and a dict whose values may be tensors (one
        level deep). Anything else is returned untouched. The input is never
        mutated: if a dict needs at least one tensor moved, a new plain dict
        is returned; otherwise the original object is returned as-is.

        Args:
            value (Any): Task input or result.
            device (torch.device): Destination device.

        Returns:
            Any: The value, with tensors on ``device``.
        """

        def _needs_move(item: Any) -> bool:
            return isinstance(item, torch.Tensor) and item.device != device

        if isinstance(value, torch.Tensor):
            return value.to(device) if _needs_move(value) else value
        if isinstance(value, dict):
            if not any(_needs_move(item) for item in value.values()):
                return value
            return {
                key: (item.to(device) if _needs_move(item) else item)
                for key, item in value.items()
            }
        return value

    def run(self) -> Iterator[Tuple["Task", Any]]:
        """
        Execute the computed schedule and yield the target values.

        Inputs of tasks that report ``uses_accelerator()`` are moved to the
        math device before execution; results are moved to the storage device
        before being cached. Cached results are dropped once the last task
        that depends on them has run.

        Yields:
            Iterator[Tuple[Task, Any]]: An iterator of task-result pairs.
        """
        # Walk the schedule backwards so the first index recorded for a task
        # is the last position at which its value is needed.
        last_use_index: Dict["Task", int] = {}
        for idx, task in reversed(list(enumerate(self.schedule))):
            for dep in self.dependencies[task]:
                last_use_index.setdefault(dep, idx)
            last_use_index.setdefault(task, idx)

        # Built once so the per-task target check is not a linear list scan.
        target_set = set(self.targets)

        values: Dict["Task", Any] = {}
        for idx, task in tqdm.tqdm(enumerate(self.schedule), total=len(self.schedule)):
            use_math_device = task.uses_accelerator()

            arguments = {}
            for name, dep in task.arguments().items():
                value = values[dep]
                if use_math_device:
                    # Produces a moved copy; the cached value stays as stored.
                    value = self._move_tensors(value, self.math_device)
                arguments[name] = value
                del value

            res = task.execute(**arguments)
            del arguments

            res = self._move_tensors(res, self.storage_device)

            values[task] = res
            del res

            if task in target_set:
                yield (task, values[task])

            # Drop every cached value whose last consumer has now run.
            expired = [key for key in values if idx >= last_use_index[key]]
            for key in expired:
                del values[key]

    def execute(self) -> None:
        """
        Execute all tasks and discard results.
        """
        for _ in self.run():
            pass

    def _make_schedule(self, targets: List["Task"]) -> List["Task"]:
        """
        Compute a topologically sorted execution order for ``targets`` and
        everything they depend on.

        Also stores the dependency map on ``self.dependencies``. Ties in the
        topological order are broken by group label, then by descending
        priority.

        Args:
            targets (List[Task]): Tasks whose results are wanted.

        Returns:
            List[Task]: Tasks in execution order.
        """
        self.dependencies = self._build_dependencies(targets)

        edge_tups = []
        for node, deps in self.dependencies.items():
            for dependency in deps:
                edge_tups.append((dependency, node))

        # The graph is built from edges only, so a target with no dependencies
        # and no dependents would otherwise be missing from it. An edge from a
        # dummy node to every target guarantees each target is present.
        for task in targets:
            edge_tups.append((Executor.DUMMY_TASK_VALUE, task))

        def _compare_key(task: Union["Task", str]):
            if task == Executor.DUMMY_TASK_VALUE:
                return ("", 0)
            return (
                task.group_label() or "",
                -task.priority(),
            )

        graph = networkx.DiGraph(edge_tups)
        return [
            t
            for t in networkx.lexicographical_topological_sort(graph, key=_compare_key)
            if t != Executor.DUMMY_TASK_VALUE
        ]

    def _build_dependencies(self, targets: List["Task"]) -> Dict["Task", Set["Task"]]:
        """
        Collect the direct dependencies of every task reachable from ``targets``.

        Args:
            targets (List[Task]): Starting tasks.

        Returns:
            Dict[Task, Set[Task]]: Each reachable task (targets included)
            mapped to the set of tasks returned by its ``arguments()``.
        """
        task_dependencies: Dict["Task", Set["Task"]] = {}
        to_process = list(targets)
        while to_process:
            child = to_process.pop()
            if child in task_dependencies:
                continue

            task_dependencies[child] = set()
            for dep in child.arguments().values():
                task_dependencies[child].add(dep)
                to_process.append(dep)
        return task_dependencies
| |
|