Spaces:
Running
on
Zero
Running
on
Zero
| import logging | |
| import os | |
| from functools import wraps | |
| import torch | |
def get_logger(name):
    """Return the logger called *name*, configured for INFO console output.

    The logger level is set to ``logging.INFO`` and a ``StreamHandler`` with
    a ``time - name - level - message`` format is attached.  The handler is
    attached only when the logger does not already have a console handler, so
    calling this function repeatedly with the same name does not make every
    record appear multiple times.

    Args:
        name: Logger name passed to ``logging.getLogger``.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # ``FileHandler`` subclasses ``StreamHandler``; exclude it so that an
    # existing file handler does not suppress console output.
    has_console_handler = any(
        isinstance(handler, logging.StreamHandler)
        and not isinstance(handler, logging.FileHandler)
        for handler in logger.handlers
    )
    if not has_console_handler:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger


# Module-level logger used by the utilities defined in this file.
logger = get_logger('hy3dgen.shapgen')
class synchronize_timer:
    """Synchronized CUDA timer for measuring e.g. `nn.Module.forward` calls.

    Supports both context manager and decorator usage.  Timing is performed
    only when the environment variable ``HY3DGEN_DEBUG`` equals ``'1'`` and
    CUDA is available at the moment the timed region is entered; otherwise
    entering and leaving the timer does nothing.

    Example as context manager:
    ```python
    with synchronize_timer('name') as t:
        run()
    ```
    ``t`` is a zero-argument callable returning the elapsed milliseconds
    when timing is active, and ``None`` when it is not.

    Example as decorator:
    ```python
    @synchronize_timer('Export to trimesh')
    def export_to_trimesh(mesh_output):
        pass
    ```

    Attributes set by the timer:
        name: Label used in the log message; nothing is logged when ``None``.
        time: Last measured duration in milliseconds, ``None`` before the
            first completed measurement.
    """

    def __init__(self, name=None):
        self.name = name
        # Initialised up front so reading ``time`` never raises AttributeError.
        self.time = None
        # Records the decision taken in ``__enter__`` so ``__exit__`` agrees
        # with it even if the environment changes inside the timed region.
        self._active = False

    def __enter__(self):
        """Context manager entry: start timing.

        Returns:
            A callable returning ``self.time`` when timing is active,
            otherwise ``None``.
        """
        # Timing is a debugging aid: without a CUDA device it is skipped
        # rather than allowed to crash the code being timed.
        self._active = (
            os.environ.get('HY3DGEN_DEBUG', '0') == '1'
            and torch.cuda.is_available()
        )
        if not self._active:
            return None
        self.start = torch.cuda.Event(enable_timing=True)
        self.end = torch.cuda.Event(enable_timing=True)
        self.start.record()
        return lambda: self.time

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Context manager exit: stop timing and log results.

        Returns ``None``, so an exception raised inside the timed region is
        propagated to the caller.
        """
        if not self._active:
            return
        self.end.record()
        # Wait for queued GPU work so both events have completed before the
        # elapsed time is read.
        torch.cuda.synchronize()
        self.time = self.start.elapsed_time(self.end)
        if self.name is not None:
            logger.info(f'{self.name} takes {self.time} ms')

    def __call__(self, func):
        """Decorator: wrap the function to time its execution.

        Args:
            func: The callable to time.

        Returns:
            A wrapper that runs ``func`` inside this timer and returns its
            result; the wrapper keeps ``func``'s name and docstring.
        """

        @wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                result = func(*args, **kwargs)
            return result

        return wrapper