Commit
·
b8956ed
1
Parent(s):
2fa0d24
removed redundant class from src/config
Browse files- src/config/config_loader.py +21 -21
- src/config/schemas.py +0 -77
src/config/config_loader.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import yaml
|
| 2 |
import logging
|
| 3 |
from pathlib import Path
|
| 4 |
-
from typing import Union, Tuple
|
| 5 |
from pydantic import ValidationError
|
| 6 |
|
| 7 |
from .schemas import SystemConfig, ModelConfig
|
|
@@ -13,58 +13,58 @@ class ConfigLoader:
|
|
| 13 |
def __init__(self):
|
| 14 |
self.logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
-
def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
|
| 17 |
"""
|
| 18 |
Load and validate configuration files from YAML files.
|
| 19 |
|
| 20 |
Args:
|
| 21 |
system_config_path: Path to YAML configuration file for OFDM-related parameters
|
| 22 |
-
model_config_path:
|
| 23 |
|
| 24 |
Returns:
|
| 25 |
-
Tuple of (SystemConfig, ModelConfig): Validated configuration objects
|
| 26 |
|
| 27 |
Raises:
|
| 28 |
-
FileNotFoundError: If
|
| 29 |
ValueError: If configuration validation fails
|
| 30 |
"""
|
| 31 |
system_config_path = Path(system_config_path)
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
|
| 34 |
if not system_config_path.exists():
|
| 35 |
raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
|
| 36 |
|
| 37 |
-
if not model_config_path.exists():
|
| 38 |
-
raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
|
| 39 |
-
|
| 40 |
try:
|
| 41 |
with open(system_config_path, 'r') as f:
|
| 42 |
system_raw_config = yaml.safe_load(f)
|
| 43 |
except yaml.YAMLError as e:
|
| 44 |
raise ValueError(f"Failed to parse YAML file {system_config_path}: {e}")
|
| 45 |
|
| 46 |
-
try:
|
| 47 |
-
with open(model_config_path, 'r') as f:
|
| 48 |
-
model_raw_config = yaml.safe_load(f)
|
| 49 |
-
except yaml.YAMLError as e:
|
| 50 |
-
raise ValueError(f"Failed to parse YAML file {model_config_path}: {e}")
|
| 51 |
-
|
| 52 |
try:
|
| 53 |
system_config = SystemConfig(**system_raw_config)
|
| 54 |
self.logger.info(f"Successfully loaded system config from {system_config_path}")
|
| 55 |
except ValidationError as e:
|
| 56 |
raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
|
| 57 |
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
| 64 |
return system_config, model_config
|
| 65 |
|
| 66 |
|
| 67 |
-
def load_config(system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
|
| 68 |
"""Convenience function to load and validate config."""
|
| 69 |
config_loader = ConfigLoader()
|
| 70 |
return config_loader.load_and_validate(system_config_path, model_config_path)
|
|
|
|
| 1 |
import yaml
|
| 2 |
import logging
|
| 3 |
from pathlib import Path
|
| 4 |
+
from typing import Union, Tuple, Optional
|
| 5 |
from pydantic import ValidationError
|
| 6 |
|
| 7 |
from .schemas import SystemConfig, ModelConfig
|
|
|
|
| 13 |
def __init__(self):
|
| 14 |
self.logger = logging.getLogger(__name__)
|
| 15 |
|
| 16 |
+
def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: Optional[Union[str, Path]] = None) -> Tuple[SystemConfig, Optional[ModelConfig]]:
|
| 17 |
"""
|
| 18 |
Load and validate configuration files from YAML files.
|
| 19 |
|
| 20 |
Args:
|
| 21 |
system_config_path: Path to YAML configuration file for OFDM-related parameters
|
| 22 |
+
model_config_path: Optional path to YAML configuration file for model-related parameters
|
| 23 |
|
| 24 |
Returns:
|
| 25 |
+
Tuple of (SystemConfig, Optional[ModelConfig]): Validated configuration objects
|
| 26 |
|
| 27 |
Raises:
|
| 28 |
+
FileNotFoundError: If the system config file doesn't exist
|
| 29 |
ValueError: If configuration validation fails
|
| 30 |
"""
|
| 31 |
system_config_path = Path(system_config_path)
|
| 32 |
+
model_config = None
|
| 33 |
+
if model_config_path is not None:
|
| 34 |
+
model_config_path = Path(model_config_path)
|
| 35 |
|
| 36 |
if not system_config_path.exists():
|
| 37 |
raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
|
| 38 |
|
|
|
|
|
|
|
|
|
|
| 39 |
try:
|
| 40 |
with open(system_config_path, 'r') as f:
|
| 41 |
system_raw_config = yaml.safe_load(f)
|
| 42 |
except yaml.YAMLError as e:
|
| 43 |
raise ValueError(f"Failed to parse YAML file {system_config_path}: {e}")
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
try:
|
| 46 |
system_config = SystemConfig(**system_raw_config)
|
| 47 |
self.logger.info(f"Successfully loaded system config from {system_config_path}")
|
| 48 |
except ValidationError as e:
|
| 49 |
raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
|
| 50 |
|
| 51 |
+
# Only load model config if path is provided and file exists
|
| 52 |
+
if model_config_path is not None and model_config_path.exists():
|
| 53 |
+
try:
|
| 54 |
+
with open(model_config_path, 'r') as f:
|
| 55 |
+
model_raw_config = yaml.safe_load(f)
|
| 56 |
+
except yaml.YAMLError as e:
|
| 57 |
+
raise ValueError(f"Failed to parse YAML file {model_config_path}: {e}")
|
| 58 |
+
try:
|
| 59 |
+
model_config = ModelConfig(**model_raw_config)
|
| 60 |
+
self.logger.info(f"Successfully loaded model config from {model_config_path}")
|
| 61 |
+
except ValidationError as e:
|
| 62 |
+
raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")
|
| 63 |
|
| 64 |
return system_config, model_config
|
| 65 |
|
| 66 |
|
| 67 |
+
def load_config(system_config_path: Union[str, Path], model_config_path: Optional[Union[str, Path]] = None) -> Tuple[SystemConfig, Optional[ModelConfig]]:
|
| 68 |
"""Convenience function to load and validate config."""
|
| 69 |
config_loader = ConfigLoader()
|
| 70 |
return config_loader.load_and_validate(system_config_path, model_config_path)
|
src/config/schemas.py
CHANGED
|
@@ -13,83 +13,6 @@ class PilotParams(BaseModel):
|
|
| 13 |
num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
|
| 14 |
|
| 15 |
|
| 16 |
-
class ModelParams(BaseModel):
|
| 17 |
-
patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
|
| 18 |
-
num_layers: int = Field(..., gt=0, description="Number of transformer layers")
|
| 19 |
-
model_dim: int = Field(..., gt=0, description="Model dimension")
|
| 20 |
-
num_head: int = Field(..., gt=0, description="Number of attention heads")
|
| 21 |
-
activation: str = Field(default="gelu", description="Activation function")
|
| 22 |
-
dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate")
|
| 23 |
-
max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
|
| 24 |
-
pos_encoding_type: str = Field(default="learnable", description="Position encoding type")
|
| 25 |
-
adaptive_token_length: int = Field(default=6, gt=0, description="Adaptive token length")
|
| 26 |
-
channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
|
| 27 |
-
default=None,
|
| 28 |
-
description="Hidden sizes for channel adaptation layers"
|
| 29 |
-
)
|
| 30 |
-
device: str = Field(default="cpu", description="Device to use")
|
| 31 |
-
|
| 32 |
-
@model_validator(mode='after')
|
| 33 |
-
def validate_device(self) -> Self:
|
| 34 |
-
"""Validate that the specified device is available."""
|
| 35 |
-
device_str = self.device.lower()
|
| 36 |
-
|
| 37 |
-
# Handle 'auto' case - automatically select best available device
|
| 38 |
-
if device_str == 'auto':
|
| 39 |
-
if torch.cuda.is_available():
|
| 40 |
-
self.device = 'cuda'
|
| 41 |
-
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 42 |
-
self.device = 'mps' # Apple Silicon
|
| 43 |
-
else:
|
| 44 |
-
self.device = 'cpu'
|
| 45 |
-
return self
|
| 46 |
-
|
| 47 |
-
# Validate CPU
|
| 48 |
-
if device_str == 'cpu':
|
| 49 |
-
return self
|
| 50 |
-
|
| 51 |
-
# Validate CUDA devices
|
| 52 |
-
if device_str.startswith('cuda'):
|
| 53 |
-
if not torch.cuda.is_available():
|
| 54 |
-
raise ValueError("CUDA is not available on this system")
|
| 55 |
-
|
| 56 |
-
# Handle specific CUDA device (e.g., 'cuda:0', 'cuda:1')
|
| 57 |
-
if ':' in device_str:
|
| 58 |
-
try:
|
| 59 |
-
device_id = int(device_str.split(':')[1])
|
| 60 |
-
if device_id >= torch.cuda.device_count():
|
| 61 |
-
available_devices = list(range(torch.cuda.device_count()))
|
| 62 |
-
raise ValueError(
|
| 63 |
-
f"CUDA device {device_id} not available. "
|
| 64 |
-
f"Available CUDA devices: {available_devices}"
|
| 65 |
-
)
|
| 66 |
-
except (ValueError, IndexError) as e:
|
| 67 |
-
if "invalid literal" in str(e):
|
| 68 |
-
raise ValueError(f"Invalid CUDA device format: {device_str}")
|
| 69 |
-
raise
|
| 70 |
-
|
| 71 |
-
return self
|
| 72 |
-
|
| 73 |
-
# Validate MPS (Apple Silicon)
|
| 74 |
-
if device_str == 'mps':
|
| 75 |
-
if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
|
| 76 |
-
raise ValueError("MPS is not available on this system")
|
| 77 |
-
return self
|
| 78 |
-
|
| 79 |
-
# If we get here, the device is not recognized
|
| 80 |
-
available_devices = ['cpu']
|
| 81 |
-
if torch.cuda.is_available():
|
| 82 |
-
cuda_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
|
| 83 |
-
available_devices.extend(['cuda'] + cuda_devices)
|
| 84 |
-
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
|
| 85 |
-
available_devices.append('mps')
|
| 86 |
-
|
| 87 |
-
raise ValueError(
|
| 88 |
-
f"Unsupported device: '{self.device}'. "
|
| 89 |
-
f"Available devices: {available_devices}"
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
|
| 93 |
class SystemConfig(BaseModel):
|
| 94 |
ofdm: OFDMParams
|
| 95 |
pilot: PilotParams
|
|
|
|
| 13 |
num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class SystemConfig(BaseModel):
|
| 17 |
ofdm: OFDMParams
|
| 18 |
pilot: PilotParams
|