BerkIGuler commited on
Commit
b8956ed
·
1 Parent(s): 2fa0d24

removed redundant class from src/config

Browse files
Files changed (2) hide show
  1. src/config/config_loader.py +21 -21
  2. src/config/schemas.py +0 -77
src/config/config_loader.py CHANGED
@@ -1,7 +1,7 @@
1
  import yaml
2
  import logging
3
  from pathlib import Path
4
- from typing import Union, Tuple
5
  from pydantic import ValidationError
6
 
7
  from .schemas import SystemConfig, ModelConfig
@@ -13,58 +13,58 @@ class ConfigLoader:
13
  def __init__(self):
14
  self.logger = logging.getLogger(__name__)
15
 
16
- def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
17
  """
18
  Load and validate configuration files from YAML files.
19
 
20
  Args:
21
  system_config_path: Path to YAML configuration file for OFDM-related parameters
22
- model_config_path: Path to YAML configuration file for model-related parameters
23
 
24
  Returns:
25
- Tuple of (SystemConfig, ModelConfig): Validated configuration objects
26
 
27
  Raises:
28
- FileNotFoundError: If one of the config files doesn't exist
29
  ValueError: If configuration validation fails
30
  """
31
  system_config_path = Path(system_config_path)
32
- model_config_path = Path(model_config_path)
 
 
33
 
34
  if not system_config_path.exists():
35
  raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
36
 
37
- if not model_config_path.exists():
38
- raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
39
-
40
  try:
41
  with open(system_config_path, 'r') as f:
42
  system_raw_config = yaml.safe_load(f)
43
  except yaml.YAMLError as e:
44
  raise ValueError(f"Failed to parse YAML file {system_config_path}: {e}")
45
 
46
- try:
47
- with open(model_config_path, 'r') as f:
48
- model_raw_config = yaml.safe_load(f)
49
- except yaml.YAMLError as e:
50
- raise ValueError(f"Failed to parse YAML file {model_config_path}: {e}")
51
-
52
  try:
53
  system_config = SystemConfig(**system_raw_config)
54
  self.logger.info(f"Successfully loaded system config from {system_config_path}")
55
  except ValidationError as e:
56
  raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
57
 
58
- try:
59
- model_config = ModelConfig(**model_raw_config)
60
- self.logger.info(f"Successfully loaded model config from {model_config_path}")
61
- except ValidationError as e:
62
- raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")
 
 
 
 
 
 
 
63
 
64
  return system_config, model_config
65
 
66
 
67
- def load_config(system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
68
  """Convenience function to load and validate config."""
69
  config_loader = ConfigLoader()
70
  return config_loader.load_and_validate(system_config_path, model_config_path)
 
1
  import yaml
2
  import logging
3
  from pathlib import Path
4
+ from typing import Union, Tuple, Optional
5
  from pydantic import ValidationError
6
 
7
  from .schemas import SystemConfig, ModelConfig
 
13
  def __init__(self):
14
  self.logger = logging.getLogger(__name__)
15
 
16
+ def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: Optional[Union[str, Path]] = None) -> Tuple[SystemConfig, Optional[ModelConfig]]:
17
  """
18
  Load and validate configuration files from YAML files.
19
 
20
  Args:
21
  system_config_path: Path to YAML configuration file for OFDM-related parameters
22
+ model_config_path: Optional path to YAML configuration file for model-related parameters
23
 
24
  Returns:
25
+ Tuple of (SystemConfig, Optional[ModelConfig]): Validated configuration objects
26
 
27
  Raises:
28
+ FileNotFoundError: If the system config file doesn't exist
29
  ValueError: If configuration validation fails
30
  """
31
  system_config_path = Path(system_config_path)
32
+ model_config = None
33
+ if model_config_path is not None:
34
+ model_config_path = Path(model_config_path)
35
 
36
  if not system_config_path.exists():
37
  raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
38
 
 
 
 
39
  try:
40
  with open(system_config_path, 'r') as f:
41
  system_raw_config = yaml.safe_load(f)
42
  except yaml.YAMLError as e:
43
  raise ValueError(f"Failed to parse YAML file {system_config_path}: {e}")
44
 
 
 
 
 
 
 
45
  try:
46
  system_config = SystemConfig(**system_raw_config)
47
  self.logger.info(f"Successfully loaded system config from {system_config_path}")
48
  except ValidationError as e:
49
  raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")
50
 
51
+ # Only load model config if path is provided and file exists
52
+ if model_config_path is not None and model_config_path.exists():
53
+ try:
54
+ with open(model_config_path, 'r') as f:
55
+ model_raw_config = yaml.safe_load(f)
56
+ except yaml.YAMLError as e:
57
+ raise ValueError(f"Failed to parse YAML file {model_config_path}: {e}")
58
+ try:
59
+ model_config = ModelConfig(**model_raw_config)
60
+ self.logger.info(f"Successfully loaded model config from {model_config_path}")
61
+ except ValidationError as e:
62
+ raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")
63
 
64
  return system_config, model_config
65
 
66
 
67
+ def load_config(system_config_path: Union[str, Path], model_config_path: Optional[Union[str, Path]] = None) -> Tuple[SystemConfig, Optional[ModelConfig]]:
68
  """Convenience function to load and validate config."""
69
  config_loader = ConfigLoader()
70
  return config_loader.load_and_validate(system_config_path, model_config_path)
src/config/schemas.py CHANGED
@@ -13,83 +13,6 @@ class PilotParams(BaseModel):
13
  num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
14
 
15
 
16
- class ModelParams(BaseModel):
17
- patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
18
- num_layers: int = Field(..., gt=0, description="Number of transformer layers")
19
- model_dim: int = Field(..., gt=0, description="Model dimension")
20
- num_head: int = Field(..., gt=0, description="Number of attention heads")
21
- activation: str = Field(default="gelu", description="Activation function")
22
- dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate")
23
- max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
24
- pos_encoding_type: str = Field(default="learnable", description="Position encoding type")
25
- adaptive_token_length: int = Field(default=6, gt=0, description="Adaptive token length")
26
- channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
27
- default=None,
28
- description="Hidden sizes for channel adaptation layers"
29
- )
30
- device: str = Field(default="cpu", description="Device to use")
31
-
32
- @model_validator(mode='after')
33
- def validate_device(self) -> Self:
34
- """Validate that the specified device is available."""
35
- device_str = self.device.lower()
36
-
37
- # Handle 'auto' case - automatically select best available device
38
- if device_str == 'auto':
39
- if torch.cuda.is_available():
40
- self.device = 'cuda'
41
- elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
42
- self.device = 'mps' # Apple Silicon
43
- else:
44
- self.device = 'cpu'
45
- return self
46
-
47
- # Validate CPU
48
- if device_str == 'cpu':
49
- return self
50
-
51
- # Validate CUDA devices
52
- if device_str.startswith('cuda'):
53
- if not torch.cuda.is_available():
54
- raise ValueError("CUDA is not available on this system")
55
-
56
- # Handle specific CUDA device (e.g., 'cuda:0', 'cuda:1')
57
- if ':' in device_str:
58
- try:
59
- device_id = int(device_str.split(':')[1])
60
- if device_id >= torch.cuda.device_count():
61
- available_devices = list(range(torch.cuda.device_count()))
62
- raise ValueError(
63
- f"CUDA device {device_id} not available. "
64
- f"Available CUDA devices: {available_devices}"
65
- )
66
- except (ValueError, IndexError) as e:
67
- if "invalid literal" in str(e):
68
- raise ValueError(f"Invalid CUDA device format: {device_str}")
69
- raise
70
-
71
- return self
72
-
73
- # Validate MPS (Apple Silicon)
74
- if device_str == 'mps':
75
- if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
76
- raise ValueError("MPS is not available on this system")
77
- return self
78
-
79
- # If we get here, the device is not recognized
80
- available_devices = ['cpu']
81
- if torch.cuda.is_available():
82
- cuda_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
83
- available_devices.extend(['cuda'] + cuda_devices)
84
- if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
85
- available_devices.append('mps')
86
-
87
- raise ValueError(
88
- f"Unsupported device: '{self.device}'. "
89
- f"Available devices: {available_devices}"
90
- )
91
-
92
-
93
  class SystemConfig(BaseModel):
94
  ofdm: OFDMParams
95
  pilot: PilotParams
 
13
  num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
14
 
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  class SystemConfig(BaseModel):
17
  ofdm: OFDMParams
18
  pilot: PilotParams