"""
OmniCoreX Utilities Module
Helper functions, logging setup, configuration parsing,
and common utilities used throughout the OmniCoreX system.
Features:
- Robust logging setup with configurable formats and levels.
- Configuration loader supporting YAML and JSON with overrides.
- Seed setting for reproducibility.
- Timing and benchmarking decorators.
- Various small utilities for system use.
"""
import functools
import json
import logging
import os
import random
import sys
import time
from typing import Optional

import numpy as np
import torch
import yaml

# ----------------------- Logging Setup ----------------------- #
def setup_logging(log_level=logging.INFO, log_file: Optional[str] = None) -> logging.Logger:
    """
    Sets up a logger with console and optional file handlers.

    Args:
        log_level: Logging level (e.g., logging.INFO).
        log_file: Optional path to a log file.

    Returns:
        Configured logger instance.
    """
    logger = logging.getLogger("OmniCoreX")
    logger.setLevel(log_level)
    formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    # Remove existing handlers so repeated calls do not duplicate output
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # Console handler
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(log_level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # File handler, if a path was given
    if log_file:
        fh = logging.FileHandler(log_file)
        fh.setLevel(log_level)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger

# Global logger instance
logger = setup_logging()
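
# Usage sketch (illustrative, not called anywhere in this module): route logs
# to a file as well as the console by passing a path. "omnicorex.log" is an
# arbitrary example name, not a path the system expects.
#
#   logger = setup_logging(log_level=logging.DEBUG, log_file="omnicorex.log")
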
# ----------------------- Configuration Loading ----------------------- #
def load_config_file(config_path: str) -> dict:
    """
    Loads a YAML or JSON configuration file.

    Args:
        config_path: Path to the config file.

    Returns:
        Dictionary of configuration parameters.
    """
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    ext = os.path.splitext(config_path)[1].lower()
    with open(config_path, "r", encoding="utf-8") as f:
        if ext in (".yaml", ".yml"):
            cfg = yaml.safe_load(f)
        elif ext == ".json":
            cfg = json.load(f)
        else:
            raise ValueError(f"Unsupported config format: {ext}")
    return cfg

def merge_dicts(base: dict, override: dict) -> dict:
    """
    Deep-merges two dictionaries, with the override taking precedence.

    Args:
        base: Base dictionary.
        override: Dictionary with override values.

    Returns:
        Merged dictionary.
    """
    result = base.copy()
    for k, v in override.items():
        if k in result and isinstance(result[k], dict) and isinstance(v, dict):
            result[k] = merge_dicts(result[k], v)
        else:
            result[k] = v
    return result
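
# Convenience sketch (illustrative; `load_config_with_overrides` is a
# hypothetical helper, not part of the original API): this is how the two
# functions above combine to give "YAML/JSON with overrides" as promised in
# the module docstring.
def load_config_with_overrides(base_path: str, override: dict) -> dict:
    """Load a config file and apply an override dict via merge_dicts."""
    base_cfg = load_config_file(base_path)
    return merge_dicts(base_cfg, override)
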
# ----------------------- Seed Setting ----------------------- #
def set_seed(seed: int = 42):
    """
    Set the seed for reproducibility across random, numpy, and torch.

    Args:
        seed: Integer seed value.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logger.info(f"Random seed set to {seed}")
# ----------------------- Timing Utilities ----------------------- #
def timeit(func):
    """
    Decorator to measure and log function execution time.

    Usage:
        @timeit
        def my_function(...):
            ...
    """
    @functools.wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        start = time.perf_counter()  # monotonic clock, better suited to benchmarking
        result = func(*args, **kwargs)
        end = time.perf_counter()
        logger.info(f"Function {func.__name__!r} executed in {(end - start):.4f}s")
        return result
    return wrapper
# ----------------------- Other Utility Functions ----------------------- #
def ensure_dir(dirname: str):
    """
    Creates the directory if it does not exist.

    Args:
        dirname: Directory path to create.
    """
    if not os.path.exists(dirname):
        os.makedirs(dirname, exist_ok=True)  # exist_ok guards against a concurrent-creation race
        logger.debug(f"Directory created: {dirname}")

def to_device(batch: dict, device: torch.device) -> dict:
    """
    Moves all tensor values in a batch dict to the specified device.

    Args:
        batch: Dictionary with tensors (non-tensor values are passed through).
        device: Target torch device.

    Returns:
        Dictionary with tensors on the device.
    """
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
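
# Illustrative extension (`to_device_recursive` is a hypothetical helper, not
# part of the original API): batches sometimes nest tensors inside lists or
# sub-dicts; a recursive variant handles those cases while passing other
# values through unchanged.
def to_device_recursive(obj, device: torch.device):
    """Recursively move tensors in dicts/lists/tuples to the device."""
    if isinstance(obj, torch.Tensor):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: to_device_recursive(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device_recursive(v, device) for v in obj)
    return obj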
if __name__ == "__main__":
# Demo usage of utilities
set_seed(1234)
logger.info("This is a test log message.")
# Create dummy config files and test merging
base_cfg = {"model": {"layers": 12, "embed_dim": 256}, "training": {"batch_size": 32}}
override_cfg = {"model": {"layers": 24}, "training": {"learning_rate": 0.001}}
merged_cfg = merge_dicts(base_cfg, override_cfg)
logger.info(f"Merged config: {merged_cfg}")
# Test directory creation
test_dir = "./tmp_test_dir"
ensure_dir(test_dir)
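
    # Round-trip sketch (added for illustration; the scratch file below is an
    # assumption, not something the original demo wrote): dump the merged
    # config to YAML inside the test directory, then load it back.
    demo_cfg_path = os.path.join(test_dir, "demo_config.yaml")
    with open(demo_cfg_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(merged_cfg, f)
    logger.info(f"Reloaded config: {load_config_file(demo_cfg_path)}")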

    # Test timing decorator
    @timeit
    def dummy_work():
        time.sleep(0.5)  # `time` is already imported at module level

    dummy_work()