# FBAGSTM's picture
# STM32 AI Experimentation Hub
# 747451d
###################################################################################
# Copyright (c) 2021, 2024 STMicroelectronics.
# All rights reserved.
# This software is licensed under terms that can be found in the LICENSE file in
# the root directory of this software component.
# If no LICENSE file comes with this software, it is provided AS-IS.
###################################################################################
"""
ST AI runner - Utilities
"""
import sys
import logging
import re
import copy
from io import StringIO
from typing import Union, List
from colorama import init, Fore, Style
# Default logger name; set_log_level() uses it when no logger is provided.
_LOGGER_NAME_ = 'STAI-RUNNER'
# Alias of the default logger name without the leading underscore.
STMAI_RUNNER_LOGGER_NAME = _LOGGER_NAME_
# Pattern of the ANSI escape sequences removed by escape_ansi(): an introducer
# (the single character 0x9B, or ESC followed by '[') and its parameter,
# intermediate and final characters. Compiled once at import time because
# escape_ansi() is called for every formatted log message.
_ANSI_ESCAPE_RE = re.compile(r'(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]')


def escape_ansi(line):
    """Remove the ANSI escape sequences from a string.

    Args:
        line: Value to clean. It is converted with ``str()`` first, so
            non-string objects are accepted.

    Returns:
        The string form of ``line`` without the matched escape sequences.
    """
    return _ANSI_ESCAPE_RE.sub('', str(line))
class IndentFormatter(logging.Formatter):
    """Base log formatter able to indent the messages.

    A message starting with '->' increases the indentation applied to the
    following messages, a message starting with '<-' decreases it.
    """

    # Number of spaces currently inserted in front of a message.
    indent = 0
    # Indentation step; 0 means that messages are never shifted.
    inc = 0

    def __init__(self, format_str="%(levelname)s %(message)s"):
        super().__init__(fmt=format_str)

    def enable_inc(self, enable: bool = True):
        """Enable (step of 3 spaces) or disable (step of 0) the indentation."""
        if enable:
            self.inc = 3
        else:
            self.inc = 0

    def is_target_msg(self, msg):
        """Tell if the message, leading blanks ignored, starts with '[TARGET:'.

        A falsy message (empty string, None) is returned as-is.
        """
        if not msg:
            return msg
        return msg.strip().startswith('[TARGET:')

    def indent_msg(self, msg):
        """Return the message shifted by the current indentation.

        The '->' and '<-' markers are searched in the message once its ANSI
        sequences are removed; the returned message keeps them.
        """
        visible = escape_ansi(msg)
        # An opening marker is printed at the current level; the deeper level
        # only applies to the messages that follow it.
        step = self.inc if visible.startswith('->') else 0
        if visible.startswith('<-'):
            # A closing marker goes back one level first, never below zero.
            self.indent = max(self.indent - self.inc, 0)
        if self.indent > 0:
            msg = ' ' * self.indent + msg
        self.indent += step
        return msg
class ColorFormatter(IndentFormatter):
    """Console formatter decorating the records with ANSI colors."""

    # Color applied to a record, selected by the one-letter level tag.
    COLORS = {
        "W": Fore.LIGHTYELLOW_EX,  # WARNING
        "E": Fore.RED + Style.BRIGHT,  # ERROR
        "D": Fore.CYAN,  # DEBUG
        "I": Fore.GREEN,  # INFO
        "C": Fore.RED,  # CRITICAL
        "T": Fore.MAGENTA  # TARGET.INFO
    }

    def __init__(self, with_prefix=False):
        # When True, the logger name is appended to the level tag.
        self.with_prefix = with_prefix
        super().__init__()

    def format(self, record):
        """Rewrite the level name and the message of the record, then format it.

        The level name becomes a colored '[<tag>]' (or '[<tag>:<logger name>]')
        prefix and the message is colored accordingly. An INFO record which is
        not a target message and has no prefix requested gets an empty level
        name and an uncolored message.
        """
        if hasattr(record, 'original_levelname'):
            # The record already carries the values saved by FileFormatter:
            # use them rather than the rewritten levelname/msg fields.
            level = record.original_levelname
            from_target = record.target_msg
            text = record.original_msg
        else:
            level = record.levelname
            text = str(record.msg)
            from_target = self.is_target_msg(text)
        record.msg = self.indent_msg(text)

        tag = level[0]
        if from_target and tag == 'I':
            # Only INFO records are re-tagged as target messages.
            tag = 'T'
        color = self.COLORS.get(tag, '')
        if color:
            if tag == 'T' or level != "INFO" or self.with_prefix:
                scope = (':' + record.name) if self.with_prefix else ''
                record.levelname = color + '[' + tag + scope + ']' + Style.RESET_ALL
                record.msg = color + record.msg + Style.RESET_ALL
            else:
                record.levelname = ''
        return super().format(record)
class DefaultFormatter(IndentFormatter):
    """Console formatter producing plain text (no ANSI sequence)."""

    def __init__(self, with_prefix=False):
        # When True, the logger name is appended to the level tag.
        self.with_prefix = with_prefix
        super().__init__()

    def format(self, record):
        """Rewrite the level name and the message of the record, then format it.

        The level name becomes '[<tag>]' (or '[<tag>:<logger name>]'), where
        the tag is 'T' for a target message and the first letter of the level
        otherwise. An INFO record which is not a target message and has no
        prefix requested gets an empty level name.
        """
        if hasattr(record, 'original_levelname'):
            # The record already carries the values saved by FileFormatter:
            # use them rather than the rewritten levelname/msg fields.
            level = record.original_levelname
            from_target = record.target_msg
            text = record.original_msg
        else:
            level = record.levelname
            text = str(record.msg)
            from_target = self.is_target_msg(text)
        # The ANSI sequences are removed once the message is indented.
        record.msg = escape_ansi(self.indent_msg(text))

        tag = 'T' if from_target else level[0]
        if tag == 'T' or level != "INFO" or self.with_prefix:
            scope = (':' + record.name) if self.with_prefix else ''
            record.levelname = '[' + tag + scope + ']'
        else:
            record.levelname = ''
        return super().format(record)
class FileFormatter(IndentFormatter):
    """Formatter used for the log file (plain text, level tag always set)."""

    def __init__(self, with_prefix=False):
        """Constructor"""
        # When True, the logger name is appended to the level tag.
        self.with_prefix = with_prefix
        super().__init__()

    def format(self, record):
        """Rewrite the level name and the message of the record, then format it.

        The initial level name, the target-message flag and the indented
        message are saved on the record ('original_levelname', 'target_msg',
        'original_msg') before the 'levelname' and 'msg' fields are rewritten.
        """
        text = str(record.msg)
        level = record.levelname
        record.original_levelname = level
        record.target_msg = self.is_target_msg(text)
        record.original_msg = self.indent_msg(text)
        record.msg = escape_ansi(record.original_msg)

        tag = 'T' if record.target_msg else level[0]
        scope = (':' + record.name) if self.with_prefix else ''
        record.levelname = '[' + tag + scope + ']'
        return super().format(record)
class DefaultStrHandler(logging.StreamHandler):
    """Stream handler writing to sys.stdout unless a stream is provided."""

    def __init__(self, stream=None):
        """Create the handler on `stream`, or on sys.stdout when it is None."""
        if stream is None:
            stream = sys.stdout
        super().__init__(stream)
def get_logger(name=None, level=logging.WARNING, color=True, filename=None, with_prefix=True):
    """Return the logger `name`, configuring it on first use.

    A logger which already holds the expected handlers (a console handler,
    optionally preceded by a file handler) is returned unchanged. Otherwise a
    console handler is attached, preceded by a file handler when `filename`
    is given, and the propagation to the parent loggers is disabled.

    Args:
        name: Name passed to logging.getLogger().
        level: Level of the console handler; also the level of the logger
            when no log file is requested.
        color: Use ColorFormatter for the console instead of DefaultFormatter.
        filename: Optional path of the log file. It is opened in 'w' mode and
            receives the records from the DEBUG level.
        with_prefix: Passed to the formatters (adds the logger name to the
            level tag).

    Returns:
        The logging.Logger object.
    """
    logger = logging.getLogger(name)
    current = logger.handlers
    # Already configured: [console] or [file, console].
    if len(current) in (1, 2) and isinstance(current[-1], DefaultStrHandler):
        return logger

    if color:
        init()

    log_to_file = filename is not None
    if log_to_file:
        # The logger lets everything through; the console handler filters.
        logger.setLevel(logging.DEBUG)
        file_formatter = FileFormatter(with_prefix=with_prefix)
        file_formatter.enable_inc()
        file_handler = logging.FileHandler(filename, mode='w', encoding='utf-8')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)
    else:
        logger.setLevel(level)

    formatter_cls = ColorFormatter if color else DefaultFormatter
    console_formatter = formatter_cls(with_prefix=with_prefix)
    # With a log file, the indentation is only computed by the file formatter.
    console_formatter.enable_inc(not log_to_file)

    console = DefaultStrHandler()
    console.setLevel(level)
    console.setFormatter(console_formatter)
    logger.addHandler(console)

    logger.propagate = False
    return logger
def get_log_level(logger: logging.Logger):
    """Return the log level used for the console.

    When the logger holds one or two handlers and the last one is a
    DefaultStrHandler (console handler, with or without a log file before
    it), the level of this handler is returned. Otherwise the effective
    level of the logger is returned.
    """
    handlers = logger.handlers
    if len(handlers) in (1, 2) and isinstance(handlers[-1], DefaultStrHandler):
        return handlers[-1].level
    return logger.getEffectiveLevel()
def set_log_level(level: Union[str, int] = logging.DEBUG, logger: Union[logging.Logger, None] = None):
    """Set the log level of the module.

    The level is applied to the logger and to its console handler. When a
    log file is used, the file handler is registered before the console
    handler: it is left untouched, so the level reported by get_log_level()
    follows the value set here.

    Args:
        level: Numeric level or level name (case-insensitive).
        logger: Logger to update; the default logger of the module is used
            when it is None.

    Raises:
        ValueError: If `level` is an unknown level name (raised by the
            logging module).
    """
    if isinstance(level, str):
        level = logging.getLevelName(level.upper())
    if logger is None:
        logger = get_logger(_LOGGER_NAME_)
    logger.setLevel(level)

    handlers = logger.handlers
    if handlers:
        # The console handler is the first stream handler which is not a file
        # handler (FileHandler derives from StreamHandler). The first handler
        # is used as fallback when no console handler is registered.
        console = next(
            (hdl for hdl in handlers
             if isinstance(hdl, logging.StreamHandler) and not isinstance(hdl, logging.FileHandler)),
            handlers[0])
        console.setLevel(level)
class TableWriter(StringIO):
    """Pretty-print tabular data (table form).

    The header, the rows, the separators and the notes are recorded first;
    the text of the table is only built when getvalue() (or str()) is called.
    """

    # Number of spaces added to the widest item of a column.
    N_SPACE = 2

    def __init__(self, indent: int = 0, csep: str = ' '):
        """Create the Table instance.

        Args:
            indent: Number of spaces written in front of each line (a
                negative value is considered as 0).
            csep: String written after each column of a separator line. In
                the data rows, the columns are followed by the same number
                of spaces.
        """
        self._header = []  # type: List[str]
        self._notes = []  # type: List[str]
        self._datas = []  # type: List[Union[List[str], str]]
        self._title = ''  # type: str
        self._fmt = ''  # type: str
        self._sizes = []  # type: List[int]
        self._indent = int(max(indent, 0))
        self._csep = csep
        super().__init__()

    def set_header(self, items: Union[List[str], str]):
        """Set the name of the columns.

        Args:
            items: Column names (a single string defines one column).
        """
        self._header = self._update_sizes(items)

    def set_title(self, title: str):
        """Set the title (optional), written above the table."""
        self._title = title

    def set_fmt(self, fmt: str):
        """Set the default format description (optional).

        Args:
            fmt: One character per column; '>' right-aligns the column, any
                other character left-aligns it.
        """
        self._fmt = fmt

    def add_note(self, note: str):
        """Add a note (footer position)."""
        self._notes.append(note)

    def add_row(self, items: Union[List[str], str]):
        """Add a row (list of items).

        Args:
            items: Items of the row (a single string defines one item).
        """
        self._datas.append(self._update_sizes(items))

    def add_separator(self, value: str = '-'):
        """Add a separator (line) built by repeating `value`."""
        self._datas.append(value)

    def _update_sizes(self, items: Union[List[str], str]) -> List[str]:
        """Update the column sizes with the items of a new row.

        The first row (or header) registered defines the number of columns.

        Args:
            items: Items of the row (a single string defines one item).

        Returns:
            The items as a list.

        Raises:
            ValueError: If the row has more items than the table has columns.
        """
        items = [items] if isinstance(items, str) else items
        if not self._sizes:
            self._sizes = [len(str(item)) + TableWriter.N_SPACE for item in items]
            return items
        if len(items) > len(self._sizes):
            raise ValueError('Size of the provided row is invalid')
        for idx, item in enumerate(items):
            self._sizes[idx] = max(len(str(item)) + TableWriter.N_SPACE, self._sizes[idx])
        return items

    def _write_row(self, items: List[str], fmt):
        """Write a formatted row.

        Args:
            items: Items of the row.
            fmt: Format description; the columns without description are
                left-aligned.
        """
        gap = ' ' * len(self._csep)
        cells = []
        for idx, item in enumerate(items):
            text = str(item)
            right_aligned = idx < len(fmt) and fmt[idx] == '>'
            if right_aligned:
                cells.append(text.rjust(self._sizes[idx]) + gap)
            else:
                cells.append(text.ljust(self._sizes[idx]) + gap)
        self.write(''.join(cells))

    def _write_separator(self, val: str):
        """Write a separator line: `val` repeated over the width of each column."""
        self.write(''.join(val * size + self._csep for size in self._sizes))

    def write(self, msg: str, newline: str = '\n'):
        """Write the indented message, followed by `newline`, to the buffer."""
        super().write(' ' * self._indent + msg + newline)

    def getvalue(self, fmt: str = '', endline: bool = False):
        """Build and return the formatted table.

        The rendered text is removed from the buffer before returning, so
        that successive calls (including through str()) return the same text
        instead of an accumulation of the previous renderings. The content
        written with write() before the call is kept.

        Args:
            fmt: Format description; the one set with set_fmt() is used when
                it is empty.
            endline: Write a closing separator even when there is no note.

        Returns:
            The content of the buffer followed by the formatted table.
        """
        fmt = fmt if fmt else self._fmt
        # Position where the rendering starts, to discard it afterwards.
        start = self.tell()
        self.write('')
        if self._title:
            self.write(self._title)
            self._write_separator('-')
        if self._header:
            self._write_row(self._header, fmt)
            self._write_separator('-')
        for data in self._datas:
            # A string is a separator, a list is a row.
            if isinstance(data, str):
                self._write_separator(data)
            else:
                self._write_row(data, fmt)
        if endline or self._notes:
            self._write_separator('-')
        for note in self._notes:
            self.write(note)
        buff = super().getvalue()
        self.seek(start)
        self.truncate()
        return buff

    def __str__(self):
        return self.getvalue()
def truncate_name(name: str, maxlen: int = 30):
    """Return the name shortened to `maxlen` characters.

    A name longer than `maxlen` (4 at least) is rebuilt from its beginning,
    the '..' marker and its end: the last character is kept when `maxlen` is
    12 or less, the last 10 characters otherwise. A shorter name is returned
    unchanged.
    """
    limit = maxlen if maxlen > 4 else 4
    if len(name) <= limit:
        return name
    if limit <= 12:
        head_cut, tail_len = 3, 1
    else:
        head_cut, tail_len = 12, 10
    return name[:limit - head_cut] + ".." + name[-tail_len:]