Zorrojurro commited on
Commit
c94d257
·
verified ·
1 Parent(s): 7205c68

Upload src/utils/config.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/utils/config.py +105 -0
src/utils/config.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration management for the Thermal Pattern Analysis project.
3
+
4
+ Loads YAML configs, provides attribute-style access, and handles
5
+ device selection + reproducibility seeding.
6
+ """
7
+
8
+ import os
9
+ import yaml
10
+ import torch
11
+ import random
12
+ import numpy as np
13
+ from pathlib import Path
14
+
15
+
16
+ class Config:
17
+ """Hierarchical configuration with attribute-style access."""
18
+
19
+ def __init__(self, config_dict: dict):
20
+ for key, value in config_dict.items():
21
+ if isinstance(value, dict):
22
+ setattr(self, key, Config(value))
23
+ elif isinstance(value, list):
24
+ setattr(self, key, [
25
+ Config(v) if isinstance(v, dict) else v for v in value
26
+ ])
27
+ else:
28
+ setattr(self, key, value)
29
+
30
+ def to_dict(self) -> dict:
31
+ """Convert back to a plain dictionary."""
32
+ result = {}
33
+ for key, value in self.__dict__.items():
34
+ if isinstance(value, Config):
35
+ result[key] = value.to_dict()
36
+ elif isinstance(value, list):
37
+ result[key] = [
38
+ v.to_dict() if isinstance(v, Config) else v for v in value
39
+ ]
40
+ else:
41
+ result[key] = value
42
+ return result
43
+
44
+ def __repr__(self):
45
+ return f"Config({self.to_dict()})"
46
+
47
+ def get(self, key, default=None):
48
+ """Safe attribute access with a default value."""
49
+ return getattr(self, key, default)
50
+
51
+
52
+ def load_config(config_path: str = "configs/config.yaml") -> Config:
53
+ """
54
+ Load configuration from a YAML file.
55
+
56
+ Args:
57
+ config_path: Path to the YAML configuration file.
58
+
59
+ Returns:
60
+ Config object with attribute-style access.
61
+ """
62
+ config_path = Path(config_path)
63
+ if not config_path.exists():
64
+ raise FileNotFoundError(f"Config file not found: {config_path}")
65
+
66
+ with open(config_path, "r") as f:
67
+ config_dict = yaml.safe_load(f)
68
+
69
+ return Config(config_dict)
70
+
71
+
72
+ def setup_device(config: Config) -> torch.device:
73
+ """
74
+ Determine the compute device based on config and availability.
75
+
76
+ Auto mode picks CUDA if available, otherwise CPU.
77
+ """
78
+ device_str = config.get("device", "auto")
79
+ if device_str == "auto":
80
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+ else:
82
+ device = torch.device(device_str)
83
+ return device
84
+
85
+
86
+ def set_seed(seed: int = 42):
87
+ """Set random seeds for reproducibility across all libraries."""
88
+ random.seed(seed)
89
+ np.random.seed(seed)
90
+ torch.manual_seed(seed)
91
+ if torch.cuda.is_available():
92
+ torch.cuda.manual_seed_all(seed)
93
+ torch.backends.cudnn.deterministic = True
94
+ torch.backends.cudnn.benchmark = False
95
+
96
+
97
+ def ensure_dirs(config: Config):
98
+ """Create all output directories specified in config.paths."""
99
+ paths = config.get("paths", None)
100
+ if paths is None:
101
+ return
102
+ for attr in ["checkpoints", "logs", "results", "visualizations"]:
103
+ dir_path = paths.get(attr, None)
104
+ if dir_path:
105
+ os.makedirs(dir_path, exist_ok=True)