| | |
| | |
| | """ |
| | General-purpose utilities that are not related to signal processing. |
| | """ |
| |
|
| | import inspect |
| | import typing as tp |
| | import sys |
| | import time |
| |
|
| |
|
def simple_repr(obj, attrs: tp.Optional[tp.Sequence[str]] = None,
                overrides: tp.Optional[dict] = None):
    """
    Return a simple representation string for `obj`, e.g. `Foo(a=1,b=3)`.

    Constructor parameters of `obj.__class__` whose value equals their declared
    default are omitted. Parameters without a default, and any requested name
    that is not a constructor parameter, are always shown. Values are rendered
    with `str` (through an f-string), not `repr`.

    Args:
        obj: the object to describe.
        attrs: if not None, the list of attribute names to include, in order.
            By default, all parameters of the class constructor are used.
            Names found neither in `overrides` nor on `obj` are skipped.
        overrides: optional mapping from attribute name to the value to show
            instead of `getattr(obj, name)`.

    Returns:
        A string of the form `ClassName(name=value,...)`.
    """
    if overrides is None:
        overrides = {}
    params = inspect.signature(obj.__class__).parameters
    if attrs is None:
        attrs = list(params.keys())
    attrs_repr = []
    for attr in attrs:
        if attr in overrides:
            value = overrides[attr]
        elif hasattr(obj, attr):
            value = getattr(obj, attr)
        else:
            continue
        display = True
        if attr in params:
            default = params[attr].default
            if default is not inspect.Parameter.empty:
                try:
                    display = bool(value != default)
                except (ValueError, RuntimeError, TypeError):
                    # Array-like values (numpy arrays, tensors) compare
                    # elementwise and have an ambiguous truth value: show
                    # them rather than crash while building a repr.
                    display = True
        if display:
            attrs_repr.append(f"{attr}={value}")
    return f"{obj.__class__.__name__}({','.join(attrs_repr)})"
| |
|
| |
|
class MarkdownTable:
    """
    Simple MarkdownTable generator. The column titles should be large enough
    for the lines content. This will right align everything.

    >>> import io  # we use io purely for test purposes, default is sys.stdout.
    >>> file = io.StringIO()
    >>> table = MarkdownTable(["Item Name", "Price"], file=file)
    >>> table.header(); table.line(["Honey", "5"]); table.line(["Car", "5,000"])
    >>> print(file.getvalue().strip())  # Strip for test purposes
    | Item Name | Price |
    |-----------|-------|
    |     Honey |     5 |
    |       Car | 5,000 |
    """
    def __init__(self, columns, file=None):
        """
        Args:
            columns: sequence of column titles. The length of each title sets
                the width of its column.
            file: text stream to write to. When None (the default), the
                current `sys.stdout` is looked up on every write, so stdout
                redirections made after construction are honored.
        """
        self.columns = columns
        self.file = file

    def _writeln(self, line):
        """Write one table row made of the already padded cells in `line`."""
        # Resolve lazily: binding sys.stdout at definition time would ignore
        # later redirections (contextlib.redirect_stdout, test capture, ...).
        file = sys.stdout if self.file is None else self.file
        file.write("|" + "|".join(line) + "|\n")

    def header(self):
        """Write the title row followed by the dashed separator row."""
        self._writeln(f" {col} " for col in self.columns)
        # Each separator spans the title plus its two padding spaces.
        self._writeln("-" * (len(col) + 2) for col in self.columns)

    def line(self, line):
        """
        Write one data row.

        Each value is formatted with `format` and right-aligned to the width
        of its column title. Values longer than the title are not truncated.
        Values beyond the number of columns are ignored, and a short `line`
        yields a row with fewer cells (`zip` semantics).
        """
        self._writeln(
            f" {val:>{len(col)}} " for val, col in zip(line, self.columns))
| |
|
| |
|
class Chrono:
    """
    Measures elapsed time, calling `torch.cuda.synchronize` if necessary.
    `Chrono` instances can be used as context managers (e.g. with `with`).
    Upon exit of the block, you can access the duration of the block in seconds
    with the `duration` attribute.

    torch is optional: when it is not installed, no synchronization happens.

    >>> with Chrono() as chrono:
    ...     _ = sum(range(10_000))
    ...
    >>> print(chrono.duration < 10)  # Should be true unless on a really slow computer.
    True
    """
    def __init__(self):
        self.duration = None  # seconds; set when the `with` block exits
        self._begin = None

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work if torch is installed and CUDA is available."""
        try:
            import torch
        except ImportError:
            return
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def __enter__(self):
        # Synchronize (and import torch) before starting the clock, so that
        # GPU work queued earlier and the first torch import are not counted.
        self._synchronize()
        # perf_counter is monotonic and high resolution, unlike time.time().
        self._begin = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # CUDA kernels run asynchronously: wait for them before reading the clock.
        self._synchronize()
        self.duration = time.perf_counter() - self._begin
        # Returning None lets any exception raised in the block propagate.
| |
|