| | from functools import wraps |
| | from inspect import unwrap |
| | from typing import Callable, List |
| | import logging |
| |
|
# Module-level logger; `log_hook` emits its pass-result messages through it.
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
| | |
def inplace_wrapper(fn: Callable) -> Callable:
    """
    Convenience wrapper for passes which modify an object inplace. This
    wrapper makes them return the modified object instead.

    Whatever `fn` itself returns is discarded by the wrapper.

    Args:
        fn (Callable[Object, Any]): pass that mutates its argument in place

    Returns:
        wrapped_fn (Callable[Object, Object]): calls `fn` on the object and
            then returns that same object
    """

    @wraps(fn)
    def wrapped_fn(gm):
        # The pass's own return value is intentionally ignored; only the
        # (mutated) input object is handed back to the caller.
        fn(gm)
        return gm

    return wrapped_fn
| |
|
def log_hook(fn: Callable, level=logging.INFO) -> Callable:
    """
    Logs callable output.

    This is useful for logging output of passes. Note inplace_wrapper replaces
    the pass output with the modified object. If we want to log the original
    output, apply this wrapper before inplace_wrapper.


    ```
    def my_pass(d: Dict) -> bool:
        changed = False
        if 'foo' in d:
            d['foo'] = 'bar'
            changed = True
        return changed

    pm = PassManager(
        passes=[
            inplace_wrapper(log_hook(my_pass))
        ]
    )
    ```

    Args:
        fn (Callable[Type1, Type2]): callable whose return value is logged
        level: logging level (e.g. logging.INFO)

    Returns:
        wrapped_fn (Callable[Type1, Type2]): calls `fn`, logs the result on
            the module logger at `level`, and returns the result unchanged
    """

    @wraps(fn)
    def wrapped_fn(gm):
        val = fn(gm)
        # Hand the values to the logger as arguments instead of pre-formatting
        # them, so the message is only built when a record is really emitted.
        logger.log(level, "Ran pass %s\t Return value: %s", fn, val)
        return val

    return wrapped_fn
| |
|
| |
|
| |
|
def loop_pass(base_pass: Callable, n_iter: int = None, predicate: Callable = None):
    """
    Build a pass that applies `base_pass` repeatedly.

    Exactly one of `n_iter` or `predicate` must be specified.

    Args:
        base_pass (Callable[Object, Object]): pass to be applied in loop
        n_iter (int, optional): number of times to apply the pass. It has to
            be positive; otherwise the returned pass raises when it is called.
        predicate (Callable[Object, bool], optional): checked against the
            current object before every application; looping stops as soon
            as it returns a falsy value.

    Returns:
        new_pass (Callable[Object, Object]): the looping pass, carrying the
            metadata of `base_pass`

    Raises:
        AssertionError: if both or neither of `n_iter` and `predicate` are
            given.
    """
    has_count = n_iter is not None
    has_predicate = predicate is not None
    assert (
        has_count ^ has_predicate
    ), "Exactly one of `n_iter`or `predicate` must be specified."

    @wraps(base_pass)
    def new_pass(source):
        use_count = n_iter is not None and n_iter > 0

        # Neither a usable iteration count nor a predicate: nothing to drive
        # the loop with.
        if not use_count and predicate is None:
            raise RuntimeError(
                f"loop_pass must be given positive int n_iter (given "
                f"{n_iter}) xor predicate (given {predicate})"
            )

        current = source
        if use_count:
            for _ in range(n_iter):
                current = base_pass(current)
            return current

        while predicate(current):
            current = base_pass(current)
        return current

    return new_pass
| |
|
| |
|
| | |
| | |
| | |
| | |
| | def _validate_pass_schedule_constraint( |
| | constraint: Callable[[Callable, Callable], bool], passes: List[Callable] |
| | ): |
| | for i, a in enumerate(passes): |
| | for j, b in enumerate(passes[i + 1 :]): |
| | if constraint(a, b): |
| | continue |
| | raise RuntimeError( |
| | f"pass schedule constraint violated. Expected {a} before {b}" |
| | f" but found {a} at index {i} and {b} at index{j} in pass" |
| | f" list." |
| | ) |
| |
|
| |
|
def this_before_that_pass_constraint(this: Callable, that: Callable):
    """
    Build a constraint requiring `this` to be scheduled ahead of `that`.

    The returned callable takes a pair of passes `(a, b)` and returns False
    only when `a` equals `that` and `b` equals `this`; every other pair
    yields True.
    """

    def depends_on(a: Callable, b: Callable):
        out_of_order = a == that and b == this
        return not out_of_order

    return depends_on
| |
|
| |
|
def these_before_those_pass_constraint(these: Callable, those: Callable):
    """
    Build a constraint requiring `these` to be scheduled ahead of `those`,
    comparing passes only after they have been 'unwrapped'.

    Because of the unwrapping, the following pass list and constraint list
    would be invalid.
    ```
    passes = [
        loop_pass(pass_b, 3),
        loop_pass(pass_a, 5),
    ]

    constraints = [
        these_before_those_pass_constraint(pass_a, pass_b)
    ]
    ```

    Args:
        these (Callable): pass which should occur first
        those (Callable): pass which should occur later

    Returns:
        depends_on (Callable[[Object, Object], bool]): returns False only
            when the first argument unwraps to `those` and the second one
            unwraps to `these`; True for every other pair
    """

    def depends_on(a: Callable, b: Callable):
        # `and` short-circuits, so `b` is only unwrapped when `a` matched.
        return not (unwrap(a) == those and unwrap(b) == these)

    return depends_on
| |
|
| |
|
class PassManager:
    """
    Construct a PassManager.

    Collects passes and constraints. This defines the pass schedule, manages
    pass constraints and pass execution.

    Args:
        passes (Optional[List[Callable]]): list of passes. A pass is a
            callable which modifies an object and returns modified object
        constraints (Optional[List[Callable]]): list of constraints. A
            constraint is a callable which takes two passes (A, B), with A
            scheduled before B, and returns True if that ordering is allowed
            and False otherwise. See implementation of
            `this_before_that_pass_constraint` for example.
    """

    # Declared without class-level list defaults: every instance gets its own
    # lists in __init__, so add_pass/add_constraint never touch shared state.
    passes: List[Callable]
    constraints: List[Callable]
    _validated: bool = False

    def __init__(
        self,
        passes=None,
        constraints=None,
    ):
        # A non-empty list given by the caller is used as-is; None or an
        # empty list results in a fresh list owned by this instance.
        self.passes = passes or []
        self.constraints = constraints or []

    @classmethod
    def build_from_passlist(cls, passes):
        """
        Alternate constructor: build a manager from a list of passes only.

        Uses `cls`, so calling it on a subclass returns an instance of that
        subclass.
        """
        return cls(passes)

    def add_pass(self, _pass: Callable):
        """Append `_pass` to the schedule and mark it as not validated."""
        self.passes.append(_pass)
        self._validated = False

    def add_constraint(self, constraint):
        """Register `constraint` and mark the schedule as not validated."""
        self.constraints.append(constraint)
        self._validated = False

    def remove_pass(self, _passes: List[str]):
        """
        Drop every scheduled pass whose `__name__` is listed in `_passes`.

        Passes that have no `__name__` attribute can never match and are
        always kept.

        Args:
            _passes (Optional[List[str]]): names of the passes to remove.
                None leaves the schedule untouched.
        """
        if _passes is None:
            return
        passes_left = []
        for ps in self.passes:
            name = getattr(ps, "__name__", None)
            if name is None or name not in _passes:
                passes_left.append(ps)
        self.passes = passes_left
        self._validated = False

    def validate(self):
        """
        Validates that current pass schedule defined by `self.passes` is valid
        according to all constraints in `self.constraints`

        The outcome is cached in `self._validated`, so the check is skipped
        until the passes or constraints are modified through this class.
        """
        if self._validated:
            return
        for constraint in self.constraints:
            _validate_pass_schedule_constraint(constraint, self.passes)
        self._validated = True

    def __call__(self, source):
        """
        Validate the schedule, then run every pass in order.

        Each pass receives the output of the previous one; the first pass
        receives `source`. Returns the output of the last pass, or `source`
        itself when there are no passes.
        """
        self.validate()
        out = source
        for _pass in self.passes:
            out = _pass(out)
        return out
| |
|