BerkIGuler commited on
Commit
2ef7f80
·
1 Parent(s): 61f9e59

added basic config handling for ofdm parameters and the linear model from the paper

Browse files
config/system_config.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ ofdm:
2
+ num_scs: 120
3
+ num_symbols: 14
4
+
5
+ pilot:
6
+ num_scs: 12
7
+ num_symbols: 2
requirements.txt CHANGED
@@ -0,0 +1,2 @@
 
 
 
1
+ torch
2
+ pydantic
src/config/__init__.py ADDED
File without changes
src/config/config_loader.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Union
5
+ from pydantic import ValidationError
6
+
7
+ from .schemas import SystemConfig
8
+
9
+
10
+ class ConfigLoader:
11
+ """Simple configuration loader with validation"""
12
+
13
+ @staticmethod
14
+ def load_and_validate(config_path: Union[str, Path]) -> SystemConfig:
15
+ """
16
+ Load and validate configuration from YAML file.
17
+
18
+ Args:
19
+ config_path: Path to YAML configuration file
20
+
21
+ Returns:
22
+ ModelConfig: Validated configuration object
23
+
24
+ Raises:
25
+ FileNotFoundError: If config file doesn't exist
26
+ ValueError: If configuration validation fails
27
+ """
28
+ config_path = Path(config_path)
29
+
30
+ if not config_path.exists():
31
+ raise FileNotFoundError(f"Configuration file not found: {config_path}")
32
+
33
+ try:
34
+ with open(config_path, 'r') as f:
35
+ raw_config = yaml.safe_load(f)
36
+ except yaml.YAMLError as e:
37
+ raise ValueError(f"Failed to parse YAML file {config_path}: {e}")
38
+
39
+ try:
40
+ config = SystemConfig(**raw_config)
41
+ logging.getLogger(__name__).info(f"Successfully loaded config from {config_path}")
42
+ return config
43
+ except ValidationError as e:
44
+ raise ValueError(f"Configuration validation failed:\n{e}")
45
+
46
+
47
+ def load_config(config_path: Union[str, Path]) -> SystemConfig:
48
+ """Convenience function to load and validate config."""
49
+ return ConfigLoader.load_and_validate(config_path)
src/config/schemas.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field, model_validator
2
+ from typing import Self
3
+
4
+
5
+ class OFDMParams(BaseModel):
6
+ num_scs: int = Field(..., gt=0, description="Number of sub-carriers")
7
+ num_symbols: int = Field(..., gt=0, description="Number of OFDM symbols")
8
+
9
+
10
+ class PilotParams(BaseModel):
11
+ num_scs: int = Field(..., gt=0, description="Number of pilots across sub-carriers")
12
+ num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
13
+
14
+
15
+ class SystemConfig(BaseModel):
16
+ ofdm: OFDMParams
17
+ pilot: PilotParams
18
+ device: str = Field(
19
+ default="cpu",
20
+ pattern=r"^(cpu|cuda(:\d+)?)$", # Updated regex to allow cuda:x
21
+ description="Target device (cpu, cuda, or cuda:x where x is device index)"
22
+ )
23
+
24
+ @model_validator(mode='after')
25
+ def validate_pilot_constraints(self) -> Self:
26
+ """Ensure pilot parameters don't exceed OFDM parameters."""
27
+ if self.pilot.num_scs > self.ofdm.num_scs:
28
+ raise ValueError(
29
+ f"Pilot sub-carriers ({self.pilot.num_scs}) cannot exceed "
30
+ f"OFDM sub-carriers ({self.ofdm.num_scs})"
31
+ )
32
+
33
+ if self.pilot.num_symbols > self.ofdm.num_symbols:
34
+ raise ValueError(
35
+ f"Pilot symbols ({self.pilot.num_symbols}) cannot exceed "
36
+ f"OFDM symbols ({self.ofdm.num_symbols})"
37
+ )
38
+ return self
39
+
40
+ model_config = {"extra": "forbid"}
src/models/__init__.py ADDED
File without changes
src/models/linear.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Learned linear estimator module for OFDM channel estimation.
3
+
4
+ This module implements an estimator for transforming channel estimates at
5
+ pilot signals to complete channel estimates using a learned linear transformation.
6
+ """
7
+
8
+ from typing import Tuple
9
+ import logging
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from src.config.schemas import SystemConfig
14
+
15
+
16
+ class LinearEstimator(nn.Module):
17
+ """Learned MMSE estimator.
18
+
19
+ Attributes:
20
+ device (torch.device): Target device for computation
21
+ config (SystemConfig): Validated configuration object
22
+ ofdm_size (Tuple[int, int]): Dimensions of OFDM frame as (height, width)
23
+ height (int): number of sub-carriers
24
+ width (int): number of OFDM symbols
25
+ pilot_size (Tuple[int, int]): Dimensions of pilot signal as (height, width)
26
+ height (int): number of pilots across sub-carriers
27
+ width (int): number of pilots across OFDM symbols
28
+ """
29
+
30
+ def __init__(self, config: SystemConfig) -> None:
31
+ """Initialize the MMSE estimator.
32
+
33
+ Args:
34
+ config: Validated SystemConfig object containing OFDM system parameters
35
+ """
36
+ super().__init__()
37
+
38
+ self.config = config
39
+ self.device = torch.device(config.device)
40
+ self.logger = logging.getLogger(__name__)
41
+
42
+ # Extract dimensions from validated config
43
+ self.ofdm_size = (config.ofdm.num_scs, config.ofdm.num_symbols)
44
+ self.pilot_size = (config.pilot.num_scs, config.pilot.num_symbols)
45
+
46
+ # Calculate feature dimensions
47
+ in_feature_dim = config.pilot.num_scs * config.pilot.num_symbols
48
+ out_feature_dim = config.ofdm.num_scs * config.ofdm.num_symbols
49
+
50
+ self.logger.info(f"Initializing LinearEstimator:")
51
+ self.logger.info(f" OFDM size: {self.ofdm_size}")
52
+ self.logger.info(f" Pilot size: {self.pilot_size}")
53
+ self.logger.info(f" Input features: {in_feature_dim}")
54
+ self.logger.info(f" Output features: {out_feature_dim}")
55
+ self.logger.info(f" Device: {self.device}")
56
+
57
+ # Create linear layer
58
+ self.linear = nn.Linear(in_feature_dim, out_feature_dim)
59
+ self.to(self.device)
60
+
61
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ """Forward pass of the MMSE estimator.
63
+
64
+ Args:
65
+ x: Input tensor containing pilot signals with shape
66
+ (batch_size, pilot_size[0], pilot_size[1])
67
+
68
+ Returns:
69
+ Estimated OFDM signal tensor with shape
70
+ (batch_size, ofdm_size[0], ofdm_size[1])
71
+ """
72
+ # pytorch does nothin if input is already on correct device
73
+ x = x.to(self.device)
74
+ self.logger.debug(f"Input shape: {x.size()}")
75
+
76
+ # Validate input shape
77
+ expected_shape = (x.size(0), self.pilot_size[0], self.pilot_size[1])
78
+ if x.size() != expected_shape:
79
+ raise ValueError(
80
+ f"Expected input shape {expected_shape}, got {x.size()}"
81
+ )
82
+
83
+ # Flatten input for linear transformation
84
+ x = torch.flatten(x, start_dim=1)
85
+ self.logger.debug(f"Flattened shape: {x.size()}")
86
+
87
+ # Apply linear transformation
88
+ x = self.linear(x)
89
+ self.logger.debug(f"Linear output shape: {x.size()}")
90
+
91
+ # Reshape to OFDM dimensions
92
+ x = x.reshape(-1, self.ofdm_size[0], self.ofdm_size[1])
93
+ self.logger.debug(f"Reshaped output shape: {x.size()}")
94
+
95
+ return x
96
+
97
+ def get_config(self) -> SystemConfig:
98
+ """Get the configuration used by this estimator.
99
+
100
+ Returns:
101
+ SystemConfig: The configuration object
102
+ """
103
+ return self.config
104
+
105
+ def __repr__(self) -> str:
106
+ """String representation of the estimator."""
107
+ return (
108
+ f"LinearEstimator(\n"
109
+ f" ofdm_size={self.ofdm_size},\n"
110
+ f" pilot_size={self.pilot_size},\n"
111
+ f" device={self.device}\n"
112
+ f")"
113
+ )