Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import logging | |
| import os | |
| from termcolor import colored | |
| class PythonLogger: | |
| """Simple console logger for DL training | |
| This is a WIP | |
| """ | |
| def __init__(self, name: str = "launch"): | |
| self.logger = logging.getLogger(name) | |
| def file_logging(self, file_name: str = "launch.log"): | |
| """Log to file""" | |
| if os.path.exists(file_name): | |
| try: | |
| os.remove(file_name) | |
| except FileNotFoundError: | |
| # ignore if already removed (can happen with multiple processes) | |
| pass | |
| formatter = logging.Formatter( | |
| "[%(asctime)s - %(name)s - %(levelname)s] %(message)s", | |
| datefmt="%H:%M:%S", | |
| ) | |
| filehandler = logging.FileHandler(file_name) | |
| filehandler.setFormatter(formatter) | |
| filehandler.setLevel(logging.DEBUG) | |
| self.logger.addHandler(filehandler) | |
| def log(self, message: str): | |
| """Log message""" | |
| self.logger.info(message) | |
| def info(self, message: str): | |
| """Log info""" | |
| self.logger.info(colored(message, "light_blue")) | |
| def success(self, message: str): | |
| """Log success""" | |
| self.logger.info(colored(message, "light_green")) | |
| def warning(self, message: str): | |
| """Log warning""" | |
| self.logger.warning(colored(message, "light_yellow")) | |
| def error(self, message: str): | |
| """Log error""" | |
| self.logger.error(colored(message, "light_red")) | |
| class RankZeroLoggingWrapper: | |
| """Wrapper class to only log from rank 0 process in distributed training.""" | |
| def __init__(self, obj, dist): | |
| self.obj = obj | |
| self.dist = dist | |
| def __getattr__(self, name): | |
| attr = getattr(self.obj, name) | |
| if callable(attr): | |
| def wrapper(*args, **kwargs): | |
| if self.dist.rank == 0: | |
| return attr(*args, **kwargs) | |
| else: | |
| return None | |
| return wrapper | |
| else: | |
| return attr | |