Spaces:
Sleeping
Sleeping
feat: initial commit
Browse files- config/__init__.py +45 -0
- config/engine.yaml +0 -0
- config/service.dev.yaml +0 -0
- config/service.prd.yaml +0 -0
- pyproject.toml +20 -0
- requirements.txt +3 -1
- src/common/__init__.py +6 -0
- src/common/depth_logging.py +80 -0
- src/common/global_settings.py +57 -0
- src/common/llm.py +50 -0
- src/common/loader.py +25 -0
- src/common/logger.py +270 -0
- src/common/requests.py +145 -0
- src/common/timer.py +52 -0
- src/common/utils.py +80 -0
- src/launch.py +47 -0
config/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Configuration constants for the project."""
|
| 2 |
+
|
| 3 |
+
from os import environ
|
| 4 |
+
from os.path import join, exists, abspath, dirname
|
| 5 |
+
|
| 6 |
+
import yaml
|
| 7 |
+
from easydict import EasyDict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def load_yaml(path: str) -> EasyDict:
|
| 11 |
+
"""Load yaml file."""
|
| 12 |
+
with open(path, "r", encoding="utf8") as f:
|
| 13 |
+
config = yaml.safe_load(f)
|
| 14 |
+
return EasyDict(config)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
##################################################
|
| 18 |
+
# Environments
|
| 19 |
+
##################################################
|
| 20 |
+
DEBUG = False
|
| 21 |
+
ENV = environ.get("ENV", "dev")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
##################################################
|
| 25 |
+
# PATH
|
| 26 |
+
##################################################
|
| 27 |
+
ROOT_DIR = abspath(dirname(dirname(__file__)))
|
| 28 |
+
CONFIG_DIR = join(ROOT_DIR, "config")
|
| 29 |
+
SERVICE_CONFIG_PATH = join(CONFIG_DIR, "service.yaml")
|
| 30 |
+
if not exists(SERVICE_CONFIG_PATH):
|
| 31 |
+
SERVICE_CONFIG_PATH = join(CONFIG_DIR, "service.dev.yaml")
|
| 32 |
+
ENGINE_CONFIG_PATH = join(CONFIG_DIR, "engine.yaml")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
##################################################
|
| 36 |
+
# Configurations
|
| 37 |
+
##################################################
|
| 38 |
+
CFG_SERVICE = load_yaml(SERVICE_CONFIG_PATH)
|
| 39 |
+
CFG_ENGINE = load_yaml(ENGINE_CONFIG_PATH)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
print(ENV)
|
| 44 |
+
print(CFG_SERVICE)
|
| 45 |
+
print(CFG_ENGINE)
|
config/engine.yaml
ADDED
|
File without changes
|
config/service.dev.yaml
ADDED
|
File without changes
|
config/service.prd.yaml
ADDED
|
File without changes
|
pyproject.toml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.poetry]
|
| 2 |
+
name = "bible-guidechat"
|
| 3 |
+
version = "0.1.1"
|
| 4 |
+
description = "Bible Guide Chat"
|
| 5 |
+
authors = ["alchemine <djyoon0223@gmail.com>"]
|
| 6 |
+
readme = "README.md"
|
| 7 |
+
|
| 8 |
+
[tool.poetry.dependencies]
|
| 9 |
+
python = "^3.12"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
[build-system]
|
| 13 |
+
requires = ["poetry-core"]
|
| 14 |
+
build-backend = "poetry.core.masonry.api"
|
| 15 |
+
|
| 16 |
+
[tool.pytest.ini_options]
|
| 17 |
+
filterwarnings = [
|
| 18 |
+
"ignore::DeprecationWarning",
|
| 19 |
+
]
|
| 20 |
+
addopts = "--show-capture=no"
|
requirements.txt
CHANGED
|
@@ -1 +1,3 @@
|
|
| 1 |
-
huggingface_hub==0.25.2
|
|
|
|
|
|
|
|
|
| 1 |
+
huggingface_hub==0.25.2
|
| 2 |
+
easydict==1.13
|
| 3 |
+
pyyaml==6.0.2
|
src/common/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Commonly used package."""
|
| 2 |
+
|
| 3 |
+
from src.common.global_settings import configure_global_settings
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
configure_global_settings()
|
src/common/depth_logging.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DepthLogger class.
|
| 2 |
+
|
| 3 |
+
Log depth of code.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import contextlib
|
| 7 |
+
from collections import defaultdict
|
| 8 |
+
from functools import wraps
|
| 9 |
+
|
| 10 |
+
from src.common.timer import Timer
|
| 11 |
+
from src.common.logger import slog
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DepthManager(contextlib.ContextDecorator):
|
| 15 |
+
"""Code Depth Manager.
|
| 16 |
+
|
| 17 |
+
Attributes:
|
| 18 |
+
depth (int): Current depth of code
|
| 19 |
+
depths (dict): Depths of codes
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
depth = 1
|
| 23 |
+
depths = defaultdict(lambda: 1)
|
| 24 |
+
|
| 25 |
+
@classmethod
|
| 26 |
+
def __enter__(cls):
|
| 27 |
+
cls.depth += 1
|
| 28 |
+
return cls
|
| 29 |
+
|
| 30 |
+
@classmethod
|
| 31 |
+
def __exit__(cls, *exc):
|
| 32 |
+
cls.depths[cls.depth] = 1 # reset
|
| 33 |
+
cls.depth -= 1
|
| 34 |
+
cls.depths[cls.depth] += 1
|
| 35 |
+
return False
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def D(fn):
|
| 39 |
+
"""Depth logging decorator.
|
| 40 |
+
|
| 41 |
+
Example:
|
| 42 |
+
>>> @D
|
| 43 |
+
>>> def f():
|
| 44 |
+
... # Depth: 1
|
| 45 |
+
... g()
|
| 46 |
+
>>> @D
|
| 47 |
+
>>> def g():
|
| 48 |
+
... # Depth: 2
|
| 49 |
+
... # do something
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
def _print_fn(name, args, fn):
|
| 53 |
+
"""Print depth of code.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
name (str): Function name
|
| 57 |
+
args (tuple): Arguments of the function
|
| 58 |
+
fn (callable): Function
|
| 59 |
+
"""
|
| 60 |
+
logs = f"{' ' + name:15}| "
|
| 61 |
+
if len(args) > 0 and isinstance(
|
| 62 |
+
args[0], object
|
| 63 |
+
): # if function is method or main function
|
| 64 |
+
logs = f"{logs}{fn.__module__.split('.')[-1]}."
|
| 65 |
+
slog(f"{logs}{fn.__name__}()")
|
| 66 |
+
|
| 67 |
+
@wraps(fn)
|
| 68 |
+
def _log(*args, **kwargs):
|
| 69 |
+
name = ".".join(
|
| 70 |
+
[str(DepthManager.depths[d]) for d in range(1, DepthManager.depth + 1)]
|
| 71 |
+
)
|
| 72 |
+
_print_fn(name, args, fn)
|
| 73 |
+
|
| 74 |
+
with DepthManager():
|
| 75 |
+
with Timer(name):
|
| 76 |
+
rst = fn(*args, **kwargs)
|
| 77 |
+
|
| 78 |
+
return rst
|
| 79 |
+
|
| 80 |
+
return _log
|
src/common/global_settings.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Global settings for the project.
|
| 2 |
+
|
| 3 |
+
This module contains global configurations.
|
| 4 |
+
It should be imported and run at the start of the project to ensure consistent settings across all modules.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
# import warnings
|
| 8 |
+
|
| 9 |
+
# import numpy as np
|
| 10 |
+
# import pandas as pd
|
| 11 |
+
# import matplotlib.pyplot as plt
|
| 12 |
+
# from pandas.plotting import register_matplotlib_converters
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def configure_global_settings() -> None:
|
| 16 |
+
"""Configure global settings."""
|
| 17 |
+
pass
|
| 18 |
+
# from langchain_community.cache import InMemoryCache
|
| 19 |
+
# from langchain.globals import set_llm_cache
|
| 20 |
+
|
| 21 |
+
# set_llm_cache(InMemoryCache())
|
| 22 |
+
|
| 23 |
+
# # Warning
|
| 24 |
+
# warnings.filterwarnings("ignore")
|
| 25 |
+
|
| 26 |
+
# # Matplotlib settings
|
| 27 |
+
# register_matplotlib_converters()
|
| 28 |
+
# # plt.rc('font', family='NanumGothic')
|
| 29 |
+
# plt.rc("font", family="DejaVu Sans")
|
| 30 |
+
# plt.rc("axes", unicode_minus=False)
|
| 31 |
+
# plt.rc("font", size=20)
|
| 32 |
+
# plt.rc("figure", titlesize=40, titleweight="bold")
|
| 33 |
+
# plt.style.use("ggplot")
|
| 34 |
+
|
| 35 |
+
# # Numpy settings
|
| 36 |
+
# np.set_printoptions(
|
| 37 |
+
# suppress=True,
|
| 38 |
+
# precision=6,
|
| 39 |
+
# edgeitems=20,
|
| 40 |
+
# linewidth=100,
|
| 41 |
+
# formatter={"float": lambda x: "{:.3f}".format(x)},
|
| 42 |
+
# )
|
| 43 |
+
|
| 44 |
+
# # Pandas settings
|
| 45 |
+
# pd.set_option("display.max_rows", 1000)
|
| 46 |
+
# pd.set_option("display.max_columns", 1000)
|
| 47 |
+
# pd.set_option("display.max_colwidth", 1000)
|
| 48 |
+
# pd.set_option("display.width", 1000)
|
| 49 |
+
# pd.set_option("display.float_format", "{:.2f}".format)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Optional: Add a function to reset to default settings if needed
|
| 53 |
+
# def reset_to_default() -> None:
|
| 54 |
+
# """Reset to default settings."""
|
| 55 |
+
# plt.rcdefaults()
|
| 56 |
+
# np.set_printoptions()
|
| 57 |
+
# pd.reset_option("^display")
|
src/common/llm.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""LLM module.
|
| 2 |
+
|
| 3 |
+
# Proxy LLM
|
| 4 |
+
LiteLLM can be used for LLM inference.
|
| 5 |
+
- Model providers: https://docs.litellm.ai/docs/providers
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from langchain_core.messages import AIMessage
|
| 9 |
+
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
| 10 |
+
|
| 11 |
+
from config import CFG_ENGINE
|
| 12 |
+
from src.common.logger import log_info
|
| 13 |
+
from src.common.timer import T
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Initialize ChatOpenAI client once at module level
|
| 17 |
+
if CFG_ENGINE.inference.llm.use_proxy:
|
| 18 |
+
llm = ChatOpenAI(
|
| 19 |
+
model=CFG_ENGINE.inference.llm.model,
|
| 20 |
+
base_url=CFG_ENGINE.inference.llm.proxy.base_url,
|
| 21 |
+
)
|
| 22 |
+
embeddings = OpenAIEmbeddings()
|
| 23 |
+
else:
|
| 24 |
+
llm = ChatOpenAI(model=CFG_ENGINE.inference.llm.model)
|
| 25 |
+
embeddings = OpenAIEmbeddings()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@T
|
| 29 |
+
def completion(messages: list[dict | tuple]) -> AIMessage:
|
| 30 |
+
"""Get completion from LLM."""
|
| 31 |
+
response = llm.invoke(messages)
|
| 32 |
+
|
| 33 |
+
# Log metadata
|
| 34 |
+
log_info("Token usage:"), log_info(response.usage_metadata)
|
| 35 |
+
return response
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
async def acompletion(messages: list[dict | tuple]) -> AIMessage:
|
| 39 |
+
"""Get completion from LLM asynchronously."""
|
| 40 |
+
response = await llm.ainvoke(messages)
|
| 41 |
+
|
| 42 |
+
# Log metadata
|
| 43 |
+
log_info("Token usage:"), log_info(response.usage_metadata)
|
| 44 |
+
return response
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if __name__ == "__main__":
|
| 48 |
+
messages = [("user", "안녕! 네 이름이 뭐니?")]
|
| 49 |
+
response = completion(messages)
|
| 50 |
+
print(response.content)
|
src/common/loader.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Generic file handling functions."""
|
| 2 |
+
|
| 3 |
+
from glob import glob
|
| 4 |
+
from os.path import isfile, isdir
|
| 5 |
+
|
| 6 |
+
import yaml
|
| 7 |
+
from easydict import EasyDict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
############################################################
|
| 11 |
+
# Lambda functions
|
| 12 |
+
############################################################
|
| 13 |
+
ls_all = lambda path: [path for path in glob(f"{path}/*")]
|
| 14 |
+
ls_dir = lambda path: [path for path in glob(f"{path}/*") if isdir(path)]
|
| 15 |
+
ls_file = lambda path: [path for path in glob(f"{path}/*") if isfile(path)]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
############################################################
|
| 19 |
+
# File loading functions
|
| 20 |
+
############################################################
|
| 21 |
+
def load_yaml(path: str) -> EasyDict:
|
| 22 |
+
"""Load yaml file."""
|
| 23 |
+
with open(path, "r") as f:
|
| 24 |
+
config = yaml.safe_load(f)
|
| 25 |
+
return EasyDict(config)
|
src/common/logger.py
ADDED
|
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging module."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import json
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import logging
|
| 8 |
+
from logging.handlers import TimedRotatingFileHandler
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# https://pkg.go.dev/github.com/shafiqaimanx/pastax/colors
|
| 12 |
+
STYLES = {
|
| 13 |
+
"ENDC": "\033[0m",
|
| 14 |
+
"BOLD": "\033[1m",
|
| 15 |
+
"ITALIC": "\033[3m",
|
| 16 |
+
"UNDERLINE": "\033[4m",
|
| 17 |
+
"RED": "\033[31m",
|
| 18 |
+
"GREEN": "\033[32m",
|
| 19 |
+
"YELLOW": "\033[33m",
|
| 20 |
+
"BLUE": "\033[34m",
|
| 21 |
+
"MAGENTA": "\033[35m",
|
| 22 |
+
"CYAN": "\033[36m",
|
| 23 |
+
"DARKGRAY": "\033[90m",
|
| 24 |
+
"LIGHTRED": "\033[91m",
|
| 25 |
+
"PINK": "\033[95m",
|
| 26 |
+
"FIREBRICK": "\033[38;5;124m",
|
| 27 |
+
"ORANGERED": "\033[38;5;202m",
|
| 28 |
+
"TOMATO": "\033[38;5;203m",
|
| 29 |
+
"GRAPEFRUIT": "\033[38;5;208m",
|
| 30 |
+
"DARKORANGE": "\033[38;5;214m",
|
| 31 |
+
"OKRED": "\033[91m",
|
| 32 |
+
"OKGREEN": "\033[92m",
|
| 33 |
+
"OKYELLOW": "\033[93m",
|
| 34 |
+
"OKBLUE": "\033[94m",
|
| 35 |
+
"OKMAGENTA": "\033[95m",
|
| 36 |
+
"OKCYAN": "\033[96m",
|
| 37 |
+
None: "",
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
# With name version (for debugging)
|
| 41 |
+
LOG_FORMAT = "%(asctime)s | %(name)-12s | %(levelname)-8s | %(message)s"
|
| 42 |
+
# LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(message)s"
|
| 43 |
+
LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ANSIColorRemovingFormatter(logging.Formatter):
|
| 47 |
+
def format(self, record):
|
| 48 |
+
formatted = super().format(record)
|
| 49 |
+
return re.sub(r"\x1b\[[0-9;]*m", "", formatted)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class TqdmLogger:
|
| 53 |
+
def write(self, message: str):
|
| 54 |
+
log_info(message.lstrip("\r\n"))
|
| 55 |
+
|
| 56 |
+
def flush(self):
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def setup_advanced_logger(
|
| 61 |
+
logger_name: str = None,
|
| 62 |
+
log_level: str = "INFO",
|
| 63 |
+
log_format: str = LOG_FORMAT,
|
| 64 |
+
date_format: str = LOG_DATE_FORMAT,
|
| 65 |
+
log_to_console: bool = True,
|
| 66 |
+
log_to_file: bool = True,
|
| 67 |
+
log_dir: str = "logs",
|
| 68 |
+
file_rotation: str = "midnight",
|
| 69 |
+
file_backup_count: int = 7,
|
| 70 |
+
) -> logging.Logger:
|
| 71 |
+
"""
|
| 72 |
+
Setup an advanced logger with flexible configuration options.
|
| 73 |
+
Keeps colors in console output, removes them in file output.
|
| 74 |
+
|
| 75 |
+
Args:
|
| 76 |
+
logger_name (str): Name of the logger. If None, root logger is used.
|
| 77 |
+
log_level (str): Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
|
| 78 |
+
log_format (str): Format string for log messages.
|
| 79 |
+
date_format (str): Format string for timestamps in log messages.
|
| 80 |
+
log_to_console (bool): Whether to log to console.
|
| 81 |
+
log_to_file (bool): Whether to log to file.
|
| 82 |
+
log_dir (str): Directory to store log files.
|
| 83 |
+
log_file_prefix (str): Prefix for log file names.
|
| 84 |
+
file_rotation (str): When to rotate the log file (e.g., 'midnight', 'h' for hourly).
|
| 85 |
+
file_backup_count (int): Number of backup log files to keep.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
logging.Logger: Configured logger object.
|
| 89 |
+
"""
|
| 90 |
+
# Create logger
|
| 91 |
+
logger = logging.getLogger(logger_name)
|
| 92 |
+
logger.setLevel(getattr(logging, log_level.upper()))
|
| 93 |
+
|
| 94 |
+
# 기존 핸들러 제거
|
| 95 |
+
if logger.hasHandlers():
|
| 96 |
+
logger.handlers.clear()
|
| 97 |
+
|
| 98 |
+
# Create formatters
|
| 99 |
+
color_formatter = logging.Formatter(log_format, datefmt=date_format)
|
| 100 |
+
no_color_formatter = ANSIColorRemovingFormatter(log_format, datefmt=date_format)
|
| 101 |
+
|
| 102 |
+
# Console handler (with colors)
|
| 103 |
+
if log_to_console:
|
| 104 |
+
console_handler = logging.StreamHandler()
|
| 105 |
+
console_handler.setFormatter(color_formatter)
|
| 106 |
+
logger.addHandler(console_handler)
|
| 107 |
+
|
| 108 |
+
# File handler (without colors)
|
| 109 |
+
if log_to_file:
|
| 110 |
+
log_dir_path = Path(log_dir)
|
| 111 |
+
log_dir_path.mkdir(parents=True, exist_ok=True)
|
| 112 |
+
|
| 113 |
+
current_date = datetime.now().strftime("%Y-%m-%d")
|
| 114 |
+
log_file_name = f"{current_date}.log"
|
| 115 |
+
log_file_path = log_dir_path / log_file_name
|
| 116 |
+
|
| 117 |
+
file_handler = TimedRotatingFileHandler(
|
| 118 |
+
filename=log_file_path,
|
| 119 |
+
when=file_rotation,
|
| 120 |
+
backupCount=file_backup_count,
|
| 121 |
+
)
|
| 122 |
+
file_handler.setFormatter(no_color_formatter)
|
| 123 |
+
file_handler.flush = lambda: file_handler.stream.flush()
|
| 124 |
+
logger.addHandler(file_handler)
|
| 125 |
+
|
| 126 |
+
return logger
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def pretty_dict(s: str) -> str:
|
| 130 |
+
"""Pretty print dictionary.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
s (str): The dictionary to pretty print.
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
str: The pretty printed dictionary.
|
| 137 |
+
"""
|
| 138 |
+
json_str = json.dumps(s, indent=2, ensure_ascii=False)
|
| 139 |
+
json_str = json_str.replace('\\"', "'")
|
| 140 |
+
return json_str
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def slog(
|
| 144 |
+
msg: str, style: str | None = None, level: str = "info", dump: bool = True, **kwargs
|
| 145 |
+
) -> str:
|
| 146 |
+
"""Stylish log message.
|
| 147 |
+
|
| 148 |
+
Args:
|
| 149 |
+
msg (str): The message to log.
|
| 150 |
+
style (str): The style of the message.
|
| 151 |
+
level (str): The log level.
|
| 152 |
+
dump (bool): The dump flag.
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
str: The stylish message.
|
| 156 |
+
"""
|
| 157 |
+
try:
|
| 158 |
+
if dump:
|
| 159 |
+
msg = pretty_dict(msg)
|
| 160 |
+
msg = msg.strip('"') # remove redundant quotes
|
| 161 |
+
except:
|
| 162 |
+
pass
|
| 163 |
+
|
| 164 |
+
stylish_msg = f"{STYLES['BOLD']}{STYLES[style]}{msg}{STYLES['ENDC']}"
|
| 165 |
+
match level:
|
| 166 |
+
case "info":
|
| 167 |
+
logger.info(stylish_msg, **kwargs)
|
| 168 |
+
case "error":
|
| 169 |
+
logger.error(stylish_msg, **kwargs)
|
| 170 |
+
case "warning":
|
| 171 |
+
logger.warning(stylish_msg, **kwargs)
|
| 172 |
+
case "debug":
|
| 173 |
+
logger.debug(stylish_msg, **kwargs)
|
| 174 |
+
case _:
|
| 175 |
+
print(stylish_msg, **kwargs)
|
| 176 |
+
|
| 177 |
+
return stylish_msg
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def log_info(msg: str, dump: bool = True, **kwargs) -> str:
|
| 181 |
+
"""Stylish info log.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
msg (str): The message to log.
|
| 185 |
+
dump (bool): The dump flag. Defaults to True.
|
| 186 |
+
"""
|
| 187 |
+
return slog(msg, style="GREEN", dump=dump, **kwargs)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def log_success(msg: str, dump: bool = True, prefix: bool = True, **kwargs) -> str:
|
| 191 |
+
"""Stylish success log.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
msg (str): The message to log.
|
| 195 |
+
dump (bool): The dump flag. Defaults to True.
|
| 196 |
+
prefix (bool): The prefix flag. Defaults to True.
|
| 197 |
+
"""
|
| 198 |
+
if prefix:
|
| 199 |
+
msg = f"[SUCCESS] {msg}"
|
| 200 |
+
return slog(msg, style="OKBLUE", dump=dump, **kwargs)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def log_error(
|
| 204 |
+
msg: str,
|
| 205 |
+
dump: bool = False,
|
| 206 |
+
prefix: bool = True,
|
| 207 |
+
exc_info: Exception | None = None,
|
| 208 |
+
**kwargs,
|
| 209 |
+
) -> str:
|
| 210 |
+
"""Stylish error log.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
msg (str): The message to log.
|
| 214 |
+
dump (bool): The dump flag. Defaults to True.
|
| 215 |
+
"""
|
| 216 |
+
if prefix:
|
| 217 |
+
msg = f"[FAILED] {msg}"
|
| 218 |
+
return slog(
|
| 219 |
+
msg, style="TOMATO", level="error", dump=dump, exc_info=exc_info, **kwargs
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def log_warning(msg: str, dump: bool = False, prefix: bool = True, **kwargs) -> str:
|
| 224 |
+
"""Stylish warning log.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
msg (str): The message to log.
|
| 228 |
+
dump (bool): The dump flag. Defaults to True.
|
| 229 |
+
"""
|
| 230 |
+
if prefix:
|
| 231 |
+
msg = f"[WARNING] {msg}"
|
| 232 |
+
return slog(msg, style="GRAPEFRUIT", level="warning", dump=dump, **kwargs)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def log_api(msg: str, error: bool = False, **kwargs) -> None:
|
| 236 |
+
"""Stylish api log.
|
| 237 |
+
|
| 238 |
+
Args:
|
| 239 |
+
msg (str): The message to log.
|
| 240 |
+
error (bool): The error status of the API. Defaults to False.
|
| 241 |
+
"""
|
| 242 |
+
if error:
|
| 243 |
+
log_error("Request API:")
|
| 244 |
+
log_error(msg, dump=True)
|
| 245 |
+
else:
|
| 246 |
+
log_success("Request API:")
|
| 247 |
+
log_success(msg)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
# Setup default logger
|
| 251 |
+
logger = setup_advanced_logger()
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Disable logging for specific modules
|
| 255 |
+
for name in ("elastic_transport.transport", "urllib3.connectionpool", "httpx"):
|
| 256 |
+
_logger = logging.getLogger(name)
|
| 257 |
+
_logger.setLevel(logging.ERROR)
|
| 258 |
+
|
| 259 |
+
# Get tqdm file
|
| 260 |
+
tqdm_file = TqdmLogger()
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
if __name__ == "__main__":
|
| 264 |
+
log_info("This is an info message.")
|
| 265 |
+
log_success("This is a success message.")
|
| 266 |
+
log_error("This is an error message.")
|
| 267 |
+
log_warning("This is a warning message.")
|
| 268 |
+
log_api("This is an API message.")
|
| 269 |
+
for style in STYLES:
|
| 270 |
+
slog(f"This is a {style} message.", style=style)
|
src/common/requests.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Requests module for handling HTTP requests."""
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import aiohttp
|
| 7 |
+
import requests
|
| 8 |
+
from requests import Response, RequestException
|
| 9 |
+
|
| 10 |
+
from src.common.logger import log_api
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
DEFAULT_HEADERS = {"accept": "application/json", "Content-Type": "application/json"}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class APIError(Exception):
|
| 17 |
+
"""API Error exception.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
url (str): The URL of the API.
|
| 21 |
+
headers (dict): The headers of the API.
|
| 22 |
+
json (dict): The JSON data of the API.
|
| 23 |
+
response (Response): The response of the API.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
str: The API error message.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, url: str, headers: dict, json: dict, response: Response):
|
| 30 |
+
self.url = url
|
| 31 |
+
self.headers = headers
|
| 32 |
+
self.json = json
|
| 33 |
+
self.response = response
|
| 34 |
+
|
| 35 |
+
def __str__(self):
|
| 36 |
+
return f"""APIError: <Response [{self.response.status_code}]>
|
| 37 |
+
requests.post(
|
| 38 |
+
url="{self.url}",
|
| 39 |
+
headers={self.headers},
|
| 40 |
+
json={self.json}
|
| 41 |
+
) -> {self.response.json()}"""
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_request_log(
|
| 45 |
+
url: str, headers: dict, json: dict, response: Response | None = None
|
| 46 |
+
) -> dict:
|
| 47 |
+
"""Get the request log.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
url (str): The URL of the API.
|
| 51 |
+
headers (dict): The headers of the API.
|
| 52 |
+
json (dict): The JSON data of the API.
|
| 53 |
+
response (Response): The response of the API.
|
| 54 |
+
|
| 55 |
+
Returns:
|
| 56 |
+
dict: The request log.
|
| 57 |
+
"""
|
| 58 |
+
log = dict(
|
| 59 |
+
url=url,
|
| 60 |
+
headers=headers,
|
| 61 |
+
json=json,
|
| 62 |
+
reproduction_code=f"import requests; requests.post(url='{url}', headers={headers}, json={json})",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
if response:
|
| 66 |
+
log.update(
|
| 67 |
+
response=dict(status_code=response.status_code, json=response.json())
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
return log
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def safe_request(url: str, json: dict, headers: dict = DEFAULT_HEADERS) -> dict:
|
| 74 |
+
"""Requests with validation.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
url (str): The URL of the API.
|
| 78 |
+
json (dict): The JSON data of the API.
|
| 79 |
+
headers (dict): The headers of the API.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
Response: The response of the API.
|
| 83 |
+
"""
|
| 84 |
+
# Check the API communication validness
|
| 85 |
+
try:
|
| 86 |
+
response = requests.post(url=url, headers=headers, json=json)
|
| 87 |
+
response.raise_for_status()
|
| 88 |
+
log = get_request_log(url, headers, json)
|
| 89 |
+
log_api(log)
|
| 90 |
+
return response.json()
|
| 91 |
+
except RequestException as e:
|
| 92 |
+
log = get_request_log(url, headers, json)
|
| 93 |
+
log_api(log, error=True)
|
| 94 |
+
raise APIError(url, headers, json, response)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
async def async_safe_request(
|
| 98 |
+
session: aiohttp.ClientSession,
|
| 99 |
+
url: str,
|
| 100 |
+
data: dict,
|
| 101 |
+
headers: dict = DEFAULT_HEADERS,
|
| 102 |
+
) -> list | dict:
|
| 103 |
+
"""Post request using aiohttp.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
session (aiohttp.ClientSession): aiohttp session.
|
| 107 |
+
url (str): URL to post.
|
| 108 |
+
data (dict): Data to post.
|
| 109 |
+
headers (dict): The headers of the API.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
list | dict: Response data.
|
| 113 |
+
"""
|
| 114 |
+
async with session.post(
|
| 115 |
+
url=url,
|
| 116 |
+
headers=headers,
|
| 117 |
+
json=data,
|
| 118 |
+
) as response:
|
| 119 |
+
response.raise_for_status()
|
| 120 |
+
return await response.json()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
async def async_safe_requests(batch: list[dict]) -> list[Any]:
|
| 124 |
+
"""Post requests asynchronously.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
batch (list[dict]): List of data to post.
|
| 128 |
+
|
| 129 |
+
Returns:
|
| 130 |
+
list[Any]: List of response data.
|
| 131 |
+
|
| 132 |
+
Examples:
|
| 133 |
+
import asyncio
|
| 134 |
+
asyncio.run(async_safe_requests(batch))
|
| 135 |
+
"""
|
| 136 |
+
async with aiohttp.ClientSession() as session:
|
| 137 |
+
futures = [async_safe_request(session, **input) for input in batch]
|
| 138 |
+
responses = await asyncio.gather(*futures)
|
| 139 |
+
return responses
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
url = "https://httpbin.org/post"
|
| 144 |
+
json = {"key": "value"}
|
| 145 |
+
response = safe_request(url, json)
|
src/common/timer.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Timer class.
|
| 2 |
+
|
| 3 |
+
Context and decorator form timer.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import contextlib
|
| 7 |
+
from functools import wraps
|
| 8 |
+
from time import perf_counter
|
| 9 |
+
|
| 10 |
+
from src.common.logger import log_success, log_info
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Timer(contextlib.ContextDecorator):
|
| 14 |
+
"""Timer.
|
| 15 |
+
|
| 16 |
+
Examples:
|
| 17 |
+
>>> with Timer('Code1'):
|
| 18 |
+
... sleep(1)
|
| 19 |
+
* Code1 | 1.00s (0.02m)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, name="Elapsed time"):
|
| 23 |
+
self.name = name
|
| 24 |
+
|
| 25 |
+
def __enter__(self):
|
| 26 |
+
log_info(f"{'[START] ' + self.name:15}")
|
| 27 |
+
self.start_time = perf_counter()
|
| 28 |
+
return self
|
| 29 |
+
|
| 30 |
+
def __exit__(self, *exc):
|
| 31 |
+
elapsed_time = perf_counter() - self.start_time
|
| 32 |
+
log_success(f"{self.name:15} | {elapsed_time:.2f}s ({elapsed_time/60:.2f}m)")
|
| 33 |
+
return False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def T(fn: callable) -> callable:
|
| 37 |
+
"""Timer decorator.
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
>>> @T
|
| 41 |
+
>>> def f():
|
| 42 |
+
... sleep(1)
|
| 43 |
+
* Elapsed time | 1.00s (0.02m)
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
@wraps(fn)
|
| 47 |
+
def _log(*args, **kwargs):
|
| 48 |
+
with Timer(fn.__name__):
|
| 49 |
+
rst = fn(*args, **kwargs)
|
| 50 |
+
return rst
|
| 51 |
+
|
| 52 |
+
return _log
|
src/common/utils.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility module.
|
| 2 |
+
|
| 3 |
+
Commonly used functions and classes are here.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
from src.common.logger import log_info
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
vars_ = lambda obj: {k: v for k, v in vars(obj).items() if not k.startswith("__")}
|
| 12 |
+
str2dt = lambda s, format="%Y-%m-%d": datetime.strptime(s, format)
|
| 13 |
+
dt2str = lambda dt, format="%Y-%m-%d": dt.strftime(format)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def lmap(fn: callable, arr: list, scheduler: str | None = None) -> list:
|
| 17 |
+
"""List map.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
fn (callable): Function to apply
|
| 21 |
+
arr (list): List to apply function
|
| 22 |
+
scheduler (str, optional): Dask scheduler. Defaults to None.
|
| 23 |
+
- None | "single-threaded": Single-threaded
|
| 24 |
+
- "threads": Multi-threaded
|
| 25 |
+
- "processes": Multi-process
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
list: List of results
|
| 29 |
+
"""
|
| 30 |
+
if scheduler is None:
|
| 31 |
+
return list(map(fn, arr))
|
| 32 |
+
else:
|
| 33 |
+
from dask import delayed, compute
|
| 34 |
+
|
| 35 |
+
assert scheduler in [
|
| 36 |
+
"single-threaded",
|
| 37 |
+
"threads",
|
| 38 |
+
"processes",
|
| 39 |
+
], f"Invalid scheduler: {scheduler}"
|
| 40 |
+
tasks = (delayed(fn)(e) for e in arr)
|
| 41 |
+
return list(compute(*tasks, scheduler=scheduler))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def tprint(dic: dict) -> None:
|
| 45 |
+
"""Table print."""
|
| 46 |
+
import tabulate
|
| 47 |
+
|
| 48 |
+
# print with fancy 'psql' format
|
| 49 |
+
log_info(tabulate(dic, headers="keys", tablefmt="psql"))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def str2bool(s: str | bool) -> bool:
|
| 53 |
+
"""String to boolean."""
|
| 54 |
+
if isinstance(s, bool):
|
| 55 |
+
return s
|
| 56 |
+
if s.lower() in ("yes", "true", "t", "y", "1"):
|
| 57 |
+
return True
|
| 58 |
+
elif s.lower() in ("no", "false", "f", "n", "0"):
|
| 59 |
+
return False
|
| 60 |
+
else:
|
| 61 |
+
raise ValueError(f"Invalid input: {s} (type: {type(s)})")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class MetaSingleton(type):
|
| 65 |
+
"""Meta singleton.
|
| 66 |
+
|
| 67 |
+
Example:
|
| 68 |
+
>>> class A(metaclass=MetaSingleton):
|
| 69 |
+
... pass
|
| 70 |
+
>>> a1 = A()
|
| 71 |
+
>>> a2 = A()
|
| 72 |
+
>>> assert a1 is a2
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
_instances = {}
|
| 76 |
+
|
| 77 |
+
def __call__(cls, *args, **kwargs):
|
| 78 |
+
if cls not in cls._instances:
|
| 79 |
+
cls._instances[cls] = super(MetaSingleton, cls).__call__(*args, **kwargs)
|
| 80 |
+
return cls._instances[cls]
|
src/launch.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from time import sleep
|
| 2 |
+
|
| 3 |
+
from src.common.timer import T
|
| 4 |
+
from src.common.depth_logging import D
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@D
|
| 8 |
+
def main():
|
| 9 |
+
|
| 10 |
+
main1()
|
| 11 |
+
main2()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@D
|
| 15 |
+
def main1():
|
| 16 |
+
main11()
|
| 17 |
+
main12()
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@D
|
| 21 |
+
def main11():
|
| 22 |
+
return
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@D
|
| 26 |
+
def main12():
|
| 27 |
+
return
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@D
|
| 31 |
+
def main2():
|
| 32 |
+
main21()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@D
|
| 36 |
+
def main21():
|
| 37 |
+
return
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@T
|
| 41 |
+
def f():
|
| 42 |
+
sleep(1)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
main()
|
| 47 |
+
f()
|