Commit
·
2ef7f80
1
Parent(s):
61f9e59
added basic config handling for ofdm parameters and the linear model from the paper
Browse files- config/system_config.yaml +7 -0
- requirements.txt +2 -0
- src/config/__init__.py +0 -0
- src/config/config_loader.py +49 -0
- src/config/schemas.py +40 -0
- src/models/__init__.py +0 -0
- src/models/linear.py +113 -0
config/system_config.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ofdm:
|
| 2 |
+
num_scs: 120
|
| 3 |
+
num_symbols: 14
|
| 4 |
+
|
| 5 |
+
pilot:
|
| 6 |
+
num_scs: 12
|
| 7 |
+
num_symbols: 2
|
requirements.txt
CHANGED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
pydantic
|
src/config/__init__.py
ADDED
|
File without changes
|
src/config/config_loader.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
import logging
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import Union
|
| 5 |
+
from pydantic import ValidationError
|
| 6 |
+
|
| 7 |
+
from .schemas import SystemConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ConfigLoader:
|
| 11 |
+
"""Simple configuration loader with validation"""
|
| 12 |
+
|
| 13 |
+
@staticmethod
|
| 14 |
+
def load_and_validate(config_path: Union[str, Path]) -> SystemConfig:
|
| 15 |
+
"""
|
| 16 |
+
Load and validate configuration from YAML file.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
config_path: Path to YAML configuration file
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
ModelConfig: Validated configuration object
|
| 23 |
+
|
| 24 |
+
Raises:
|
| 25 |
+
FileNotFoundError: If config file doesn't exist
|
| 26 |
+
ValueError: If configuration validation fails
|
| 27 |
+
"""
|
| 28 |
+
config_path = Path(config_path)
|
| 29 |
+
|
| 30 |
+
if not config_path.exists():
|
| 31 |
+
raise FileNotFoundError(f"Configuration file not found: {config_path}")
|
| 32 |
+
|
| 33 |
+
try:
|
| 34 |
+
with open(config_path, 'r') as f:
|
| 35 |
+
raw_config = yaml.safe_load(f)
|
| 36 |
+
except yaml.YAMLError as e:
|
| 37 |
+
raise ValueError(f"Failed to parse YAML file {config_path}: {e}")
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
config = SystemConfig(**raw_config)
|
| 41 |
+
logging.getLogger(__name__).info(f"Successfully loaded config from {config_path}")
|
| 42 |
+
return config
|
| 43 |
+
except ValidationError as e:
|
| 44 |
+
raise ValueError(f"Configuration validation failed:\n{e}")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def load_config(config_path: Union[str, Path]) -> SystemConfig:
|
| 48 |
+
"""Convenience function to load and validate config."""
|
| 49 |
+
return ConfigLoader.load_and_validate(config_path)
|
src/config/schemas.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field, model_validator
|
| 2 |
+
from typing import Self
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class OFDMParams(BaseModel):
|
| 6 |
+
num_scs: int = Field(..., gt=0, description="Number of sub-carriers")
|
| 7 |
+
num_symbols: int = Field(..., gt=0, description="Number of OFDM symbols")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PilotParams(BaseModel):
|
| 11 |
+
num_scs: int = Field(..., gt=0, description="Number of pilots across sub-carriers")
|
| 12 |
+
num_symbols: int = Field(..., gt=0, description="Number of pilots across OFDM symbols")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SystemConfig(BaseModel):
|
| 16 |
+
ofdm: OFDMParams
|
| 17 |
+
pilot: PilotParams
|
| 18 |
+
device: str = Field(
|
| 19 |
+
default="cpu",
|
| 20 |
+
pattern=r"^(cpu|cuda(:\d+)?)$", # Updated regex to allow cuda:x
|
| 21 |
+
description="Target device (cpu, cuda, or cuda:x where x is device index)"
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
@model_validator(mode='after')
|
| 25 |
+
def validate_pilot_constraints(self) -> Self:
|
| 26 |
+
"""Ensure pilot parameters don't exceed OFDM parameters."""
|
| 27 |
+
if self.pilot.num_scs > self.ofdm.num_scs:
|
| 28 |
+
raise ValueError(
|
| 29 |
+
f"Pilot sub-carriers ({self.pilot.num_scs}) cannot exceed "
|
| 30 |
+
f"OFDM sub-carriers ({self.ofdm.num_scs})"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
if self.pilot.num_symbols > self.ofdm.num_symbols:
|
| 34 |
+
raise ValueError(
|
| 35 |
+
f"Pilot symbols ({self.pilot.num_symbols}) cannot exceed "
|
| 36 |
+
f"OFDM symbols ({self.ofdm.num_symbols})"
|
| 37 |
+
)
|
| 38 |
+
return self
|
| 39 |
+
|
| 40 |
+
model_config = {"extra": "forbid"}
|
src/models/__init__.py
ADDED
|
File without changes
|
src/models/linear.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Learned linear estimator module for OFDM channel estimation.
|
| 3 |
+
|
| 4 |
+
This module implements an estimator for transforming channel estimates at
|
| 5 |
+
pilot signals to complete channel estimates using a learned linear transformation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
import logging
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
|
| 13 |
+
from src.config.schemas import SystemConfig
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LinearEstimator(nn.Module):
|
| 17 |
+
"""Learned MMSE estimator.
|
| 18 |
+
|
| 19 |
+
Attributes:
|
| 20 |
+
device (torch.device): Target device for computation
|
| 21 |
+
config (SystemConfig): Validated configuration object
|
| 22 |
+
ofdm_size (Tuple[int, int]): Dimensions of OFDM frame as (height, width)
|
| 23 |
+
height (int): number of sub-carriers
|
| 24 |
+
width (int): number of OFDM symbols
|
| 25 |
+
pilot_size (Tuple[int, int]): Dimensions of pilot signal as (height, width)
|
| 26 |
+
height (int): number of pilots across sub-carriers
|
| 27 |
+
width (int): number of pilots across OFDM symbols
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, config: SystemConfig) -> None:
|
| 31 |
+
"""Initialize the MMSE estimator.
|
| 32 |
+
|
| 33 |
+
Args:
|
| 34 |
+
config: Validated SystemConfig object containing OFDM system parameters
|
| 35 |
+
"""
|
| 36 |
+
super().__init__()
|
| 37 |
+
|
| 38 |
+
self.config = config
|
| 39 |
+
self.device = torch.device(config.device)
|
| 40 |
+
self.logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
# Extract dimensions from validated config
|
| 43 |
+
self.ofdm_size = (config.ofdm.num_scs, config.ofdm.num_symbols)
|
| 44 |
+
self.pilot_size = (config.pilot.num_scs, config.pilot.num_symbols)
|
| 45 |
+
|
| 46 |
+
# Calculate feature dimensions
|
| 47 |
+
in_feature_dim = config.pilot.num_scs * config.pilot.num_symbols
|
| 48 |
+
out_feature_dim = config.ofdm.num_scs * config.ofdm.num_symbols
|
| 49 |
+
|
| 50 |
+
self.logger.info(f"Initializing LinearEstimator:")
|
| 51 |
+
self.logger.info(f" OFDM size: {self.ofdm_size}")
|
| 52 |
+
self.logger.info(f" Pilot size: {self.pilot_size}")
|
| 53 |
+
self.logger.info(f" Input features: {in_feature_dim}")
|
| 54 |
+
self.logger.info(f" Output features: {out_feature_dim}")
|
| 55 |
+
self.logger.info(f" Device: {self.device}")
|
| 56 |
+
|
| 57 |
+
# Create linear layer
|
| 58 |
+
self.linear = nn.Linear(in_feature_dim, out_feature_dim)
|
| 59 |
+
self.to(self.device)
|
| 60 |
+
|
| 61 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 62 |
+
"""Forward pass of the MMSE estimator.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
x: Input tensor containing pilot signals with shape
|
| 66 |
+
(batch_size, pilot_size[0], pilot_size[1])
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
Estimated OFDM signal tensor with shape
|
| 70 |
+
(batch_size, ofdm_size[0], ofdm_size[1])
|
| 71 |
+
"""
|
| 72 |
+
# pytorch does nothin if input is already on correct device
|
| 73 |
+
x = x.to(self.device)
|
| 74 |
+
self.logger.debug(f"Input shape: {x.size()}")
|
| 75 |
+
|
| 76 |
+
# Validate input shape
|
| 77 |
+
expected_shape = (x.size(0), self.pilot_size[0], self.pilot_size[1])
|
| 78 |
+
if x.size() != expected_shape:
|
| 79 |
+
raise ValueError(
|
| 80 |
+
f"Expected input shape {expected_shape}, got {x.size()}"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Flatten input for linear transformation
|
| 84 |
+
x = torch.flatten(x, start_dim=1)
|
| 85 |
+
self.logger.debug(f"Flattened shape: {x.size()}")
|
| 86 |
+
|
| 87 |
+
# Apply linear transformation
|
| 88 |
+
x = self.linear(x)
|
| 89 |
+
self.logger.debug(f"Linear output shape: {x.size()}")
|
| 90 |
+
|
| 91 |
+
# Reshape to OFDM dimensions
|
| 92 |
+
x = x.reshape(-1, self.ofdm_size[0], self.ofdm_size[1])
|
| 93 |
+
self.logger.debug(f"Reshaped output shape: {x.size()}")
|
| 94 |
+
|
| 95 |
+
return x
|
| 96 |
+
|
| 97 |
+
def get_config(self) -> SystemConfig:
|
| 98 |
+
"""Get the configuration used by this estimator.
|
| 99 |
+
|
| 100 |
+
Returns:
|
| 101 |
+
SystemConfig: The configuration object
|
| 102 |
+
"""
|
| 103 |
+
return self.config
|
| 104 |
+
|
| 105 |
+
def __repr__(self) -> str:
|
| 106 |
+
"""String representation of the estimator."""
|
| 107 |
+
return (
|
| 108 |
+
f"LinearEstimator(\n"
|
| 109 |
+
f" ofdm_size={self.ofdm_size},\n"
|
| 110 |
+
f" pilot_size={self.pilot_size},\n"
|
| 111 |
+
f" device={self.device}\n"
|
| 112 |
+
f")"
|
| 113 |
+
)
|