Kosasih committed on
Commit
5b6fc4e
·
verified ·
1 Parent(s): 0c2d7d2

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +193 -0
utils.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OmniCoreX Utilities Module
3
+
4
+ Helper functions, logging setup, configuration parsing,
5
+ and common utilities used throughout the OmniCoreX system.
6
+
7
+ Features:
8
+ - Robust logging setup with configurable formats and levels.
9
+ - Configuration loader supporting YAML and JSON with overrides.
10
+ - Seed setting for reproducibility.
11
+ - Timing and benchmarking decorators.
12
+ - Various small utilities for system use.
13
+ """
14
+
15
import functools
import json
import logging
import os
import random
import sys
import time
from typing import Optional

import numpy as np
import torch
import yaml
24
+
25
+ # ----------------------- Logging Setup ----------------------- #
26
+
27
def setup_logging(log_level=logging.INFO, log_file: Optional[str] = None) -> logging.Logger:
    """
    Set up the "OmniCoreX" logger with a console handler and an optional
    file handler.

    Args:
        log_level: Logging level (e.g., logging.INFO).
        log_file: Optional path to a log file; when given, records are also
            written there.

    Returns:
        Configured logger instance.
    """
    logger = logging.getLogger("OmniCoreX")
    logger.setLevel(log_level)
    formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s")

    # Drop handlers left over from previous calls so repeated setup does
    # not duplicate every log line.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # Console handler.
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(log_level)
    ch.setFormatter(formatter)
    logger.addHandler(ch)

    # Optional file handler.
    if log_file:
        fh = logging.FileHandler(log_file)
        fh.setLevel(log_level)
        fh.setFormatter(formatter)
        logger.addHandler(fh)

    return logger
60
+
61
# Module-wide logger instance shared by every helper in this file.
logger = setup_logging()
63
+
64
+ # ----------------------- Configuration Loading ----------------------- #
65
+
66
def load_config_file(config_path: str) -> dict:
    """
    Load a YAML or JSON configuration file.

    Args:
        config_path: Path to the config file (.yaml, .yml, or .json).

    Returns:
        Dictionary of configuration parameters; an empty YAML document
        yields an empty dict.

    Raises:
        FileNotFoundError: If config_path does not exist.
        ValueError: If the file extension is not a supported format.
    """
    if not os.path.isfile(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    ext = os.path.splitext(config_path)[1].lower()
    with open(config_path, "r", encoding="utf-8") as f:
        if ext in [".yaml", ".yml"]:
            cfg = yaml.safe_load(f)
        elif ext == ".json":
            cfg = json.load(f)
        else:
            raise ValueError(f"Unsupported config format: {ext}")

    # yaml.safe_load returns None for an empty document; normalize so
    # callers can rely on the declared dict return type.
    return cfg if cfg is not None else {}
89
+
90
def merge_dicts(base: dict, override: dict) -> dict:
    """
    Deep-merge *override* into a copy of *base*.

    Nested dicts present in both inputs are merged recursively; any other
    colliding key takes the value from *override*. Neither input is mutated.

    Args:
        base: Base dictionary.
        override: Dictionary with override values.

    Returns:
        Merged dictionary.
    """
    merged = dict(base)
    for key, value in override.items():
        existing = merged.get(key)
        # Recurse only when both sides hold a dict; otherwise override wins.
        if isinstance(existing, dict) and isinstance(value, dict):
            merged[key] = merge_dicts(existing, value)
        else:
            merged[key] = value
    return merged
108
+
109
+ # ----------------------- Seed Setting ----------------------- #
110
+
111
def set_seed(seed: int = 42):
    """
    Seed the random, numpy, and torch RNGs for reproducible runs.

    Also seeds every CUDA device when CUDA is available.

    Args:
        seed: Integer seed value.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    logger.info(f"Random seed set to {seed}")
124
+
125
+ # ----------------------- Timing Utilities ----------------------- #
126
+
127
def timeit(func):
    """
    Decorator to measure and log function execution time.

    Uses time.perf_counter() (monotonic, high resolution) rather than
    time.time(), and functools.wraps so the decorated function keeps its
    __name__/__doc__.

    Usage:
        @timeit
        def my_function(...):
            ...
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        logger.info(f"Function {func.__name__!r} executed in {elapsed:.4f}s")
        return result
    return wrapper
143
+
144
+ # ----------------------- Other Utility Functions ----------------------- #
145
+
146
def ensure_dir(dirname: str):
    """
    Create a directory (including parents) if it does not already exist.

    Args:
        dirname: Directory path to create.
    """
    if not os.path.exists(dirname):
        # exist_ok guards against a race where another process creates the
        # directory between the existence check and makedirs.
        os.makedirs(dirname, exist_ok=True)
        logger.debug(f"Directory created: {dirname}")
156
+
157
def to_device(batch: dict, device: torch.device) -> dict:
    """
    Move every tensor value in *batch* to *device*.

    Non-tensor values are passed through unchanged; the input dict is not
    mutated.

    Args:
        batch: Dictionary with tensors (and possibly other values).
        device: Target torch device.

    Returns:
        New dictionary with tensors on the target device.
    """
    moved = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            value = value.to(device)
        moved[key] = value
    return moved
169
+
170
if __name__ == "__main__":
    # Demo usage of utilities.
    set_seed(1234)
    logger.info("This is a test log message.")

    # Exercise config merging with a base and an override dict.
    base_cfg = {"model": {"layers": 12, "embed_dim": 256}, "training": {"batch_size": 32}}
    override_cfg = {"model": {"layers": 24}, "training": {"learning_rate": 0.001}}

    merged_cfg = merge_dicts(base_cfg, override_cfg)
    logger.info(f"Merged config: {merged_cfg}")

    # Exercise directory creation.
    test_dir = "./tmp_test_dir"
    ensure_dir(test_dir)

    # Exercise the timing decorator; `time` is already imported at module
    # level, so no local re-import is needed.
    @timeit
    def dummy_work():
        time.sleep(0.5)

    dummy_work()
193
+