Commit 687eaba
Parent(s): 9727e5e

fixes on src/models

Files changed:
- config/linear.yaml +2 -0
- src/config/config_loader.py +20 -21
- src/config/schemas.py +59 -52
- src/main.py +4 -3
- src/main/trainer.py +3 -7
- src/models/blocks/enhancers.py +2 -2
- src/models/blocks/patch_processors.py +6 -6
- src/models/fortitran.py +21 -15
- src/models/linear.py +22 -26
config/linear.yaml
ADDED

@@ -0,0 +1,2 @@
+model_type: 'linear'
+device: 'cpu'
src/config/config_loader.py
CHANGED

@@ -13,31 +13,30 @@ class ConfigLoader:
     def __init__(self):
         self.logger = logging.getLogger(__name__)

-    def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: …
+    def load_and_validate(self, system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
         """
         Load and validate configuration files from YAML files.

         Args:
             system_config_path: Path to YAML configuration file for OFDM-related parameters
-            model_config_path: …
+            model_config_path: Path to YAML configuration file for model-related parameters

         Returns:
-            Tuple of (SystemConfig, …
+            Tuple of (SystemConfig, ModelConfig): Validated configuration objects

         Raises:
-            FileNotFoundError: If …
+            FileNotFoundError: If either config file doesn't exist
             ValueError: If configuration validation fails
         """
         system_config_path = Path(system_config_path)
-
-        # certain models may not have a model config
-        model_config = None
-        if model_config_path is not None:
-            model_config_path = Path(model_config_path)
+        model_config_path = Path(model_config_path)

         if not system_config_path.exists():
             raise FileNotFoundError(f"System configuration file not found: {system_config_path}")

+        if not model_config_path.exists():
+            raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
+
         try:
             with open(system_config_path, 'r') as f:
                 system_raw_config = yaml.safe_load(f)

@@ -50,22 +49,22 @@ class ConfigLoader:
         except ValidationError as e:
             raise ValueError(f"System configuration validation for {system_config_path} failed:\n{e}")

-        … (old optional model-config loading block; not preserved in the page extraction)
+        try:
+            with open(model_config_path, 'r') as f:
+                model_raw_config = yaml.safe_load(f)
+        except yaml.YAMLError as e:
+            raise ValueError(f"Failed to parse YAML file {model_config_path}: {e}")
+
+        try:
+            model_config = ModelConfig(**model_raw_config)
+            self.logger.info(f"Successfully loaded model config from {model_config_path}")
+        except ValidationError as e:
+            raise ValueError(f"Model configuration validation for {model_config_path} failed:\n{e}")

         return system_config, model_config


-def load_config(system_config_path: Union[str, Path], model_config_path: …
+def load_config(system_config_path: Union[str, Path], model_config_path: Union[str, Path]) -> Tuple[SystemConfig, ModelConfig]:
     """Convenience function to load and validate config."""
     config_loader = ConfigLoader()
     return config_loader.load_and_validate(system_config_path, model_config_path)
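For reference, the tightened loader is exercised like this — a minimal sketch; `config/system.yaml` is a hypothetical path (only `config/linear.yaml` is added in this commit):

from src.config import load_config

# Both paths are now mandatory: a missing file raises FileNotFoundError,
# a file that fails schema validation raises ValueError.
system_config, model_config = load_config(
    "config/system.yaml",   # hypothetical OFDM system parameters
    "config/linear.yaml",   # model parameters (added in this commit)
)
print(model_config.model_type, model_config.device)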
src/config/schemas.py
CHANGED

@@ -1,5 +1,5 @@
 from pydantic import BaseModel, Field, model_validator
-from typing import Self, Tuple, List, Optional, Literal
+from typing import Self, Tuple, List, Optional, Literal, Union
 import torch


@@ -40,59 +40,11 @@ class SystemConfig(BaseModel):
     model_config = {"extra": "forbid"}  # forbid extra fields


-class …
-    …
-    …
-        description="Type of model (fortitran or adafortitran)"
-    )
-    patch_size: Tuple[int, int] = Field(..., description="Patch size as (height, width)")
-    num_layers: int = Field(..., gt=0, description="Number of transformer layers")
-    model_dim: int = Field(..., gt=0, description="Model dimension")
-    num_head: int = Field(..., gt=0, description="Number of attention heads")
-    activation: Literal["relu", "gelu"] = Field(
-        default="gelu",
-        description="Activation function used within the transformer's FFN"
-    )
-    dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate used within the transformer's FFN")
-    max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
-    pos_encoding_type: Literal["learnable", "sinusoidal"] = Field(
-        default="learnable",
-        description="Positional encoding type"
-    )
-    adaptive_token_length: Optional[int] = Field(
-        default=None,
-        gt=0,
-        description="Adaptive token length (required for AdaFortiTran)"
-    )
-    channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
-        default=None,
-        description="Hidden sizes for channel adaptation layers (required for AdaFortiTran)"
-    )
+class BaseConfig(BaseModel):
+    """Base configuration class with device validation."""
+
     device: str = Field(default="cpu", description="Device to use")

-    @model_validator(mode='after')
-    def validate_model_specific_requirements(self) -> Self:
-        """Validate model-specific configuration requirements."""
-        if self.model_type == "adafortitran":
-            if self.channel_adaptivity_hidden_sizes is None:
-                raise ValueError(
-                    "channel_adaptivity_hidden_sizes is required for AdaFortiTran model"
-                )
-            if self.adaptive_token_length is None:
-                raise ValueError(
-                    "adaptive_token_length is required for AdaFortiTran model"
-                )
-
-        if self.model_type == "fortitran":
-            if self.channel_adaptivity_hidden_sizes is not None:
-                # Note: channel_adaptivity_hidden_sizes will be ignored for FortiTran
-                pass
-            if self.adaptive_token_length is not None:
-                # Note: adaptive_token_length will be ignored for FortiTran
-                pass
-
-        return self
-
     @model_validator(mode='after')
     def validate_device(self) -> Self:
         """Validate that the specified device is available."""

@@ -152,4 +104,59 @@
                 f"Available devices: {available_devices}"
             )

+
+class ModelConfig(BaseConfig):
+    model_type: Literal["linear", "fortitran", "adafortitran"] = Field(
+        default="fortitran",
+        description="Type of model (linear, fortitran, or adafortitran)"
+    )
+    patch_size: Tuple[int, int] = Field(..., description="Patch size as (subcarriers_per_patch, symbols_per_patch)")
+    num_layers: int = Field(..., gt=0, description="Number of transformer layers")
+    model_dim: int = Field(..., gt=0, description="Model dimension")
+    num_head: int = Field(..., gt=0, description="Number of attention heads")
+    activation: Literal["relu", "gelu"] = Field(
+        default="gelu",
+        description="Activation function used within the transformer's FFN"
+    )
+    dropout: float = Field(default=0.1, ge=0.0, le=1.0, description="Dropout rate used within the transformer's FFN")
+    max_seq_len: int = Field(default=512, gt=0, description="Maximum sequence length")
+    pos_encoding_type: Literal["learnable", "sinusoidal"] = Field(
+        default="learnable",
+        description="Positional encoding type"
+    )
+    adaptive_token_length: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description="Adaptive token length (required for AdaFortiTran)"
+    )
+    channel_adaptivity_hidden_sizes: Optional[List[int]] = Field(
+        default=None,
+        description="Hidden sizes for channel adaptation layers (required for AdaFortiTran)"
+    )
+
+    @model_validator(mode='after')
+    def validate_model_specific_requirements(self) -> Self:
+        """Validate model-specific configuration requirements."""
+        if self.model_type == "linear":
+            # Linear model only needs device, no additional validation required
+            pass
+        elif self.model_type == "adafortitran":
+            if self.channel_adaptivity_hidden_sizes is None:
+                raise ValueError(
+                    "channel_adaptivity_hidden_sizes is required for AdaFortiTran model"
+                )
+            if self.adaptive_token_length is None:
+                raise ValueError(
+                    "adaptive_token_length is required for AdaFortiTran model"
+                )
+        elif self.model_type == "fortitran":
+            if self.channel_adaptivity_hidden_sizes is not None:
+                # Note: channel_adaptivity_hidden_sizes will be ignored for FortiTran
+                pass
+            if self.adaptive_token_length is not None:
+                # Note: adaptive_token_length will be ignored for FortiTran
+                pass
+
+        return self
+
     model_config = {"extra": "forbid"}
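A quick sketch of how the reworked validator behaves — assuming the schema exactly as diffed above; the field values are illustrative:

from pydantic import ValidationError
from src.config.schemas import ModelConfig

# Valid: fortitran needs no adaptation fields.
cfg = ModelConfig(model_type="fortitran", patch_size=(2, 2),
                  num_layers=4, model_dim=128, num_head=8)

# Invalid: adafortitran without its adaptation fields trips
# validate_model_specific_requirements, which pydantic surfaces
# as a ValidationError.
try:
    ModelConfig(model_type="adafortitran", patch_size=(2, 2),
                num_layers=4, model_dim=128, num_head=8)
except ValidationError as e:
    print(e)  # channel_adaptivity_hidden_sizes is required for AdaFortiTran model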
src/main.py
CHANGED

@@ -44,6 +44,7 @@ from pathlib import Path
 from src.main.parser import parse_arguments
 from src.main.trainer import train
 from src.config import load_config
+from src.config.schemas import ModelConfig


 def setup_logging(log_level: str) -> None:

@@ -88,10 +89,10 @@ def main() -> None:
     logger.info("Configuration loaded successfully")
     logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
     logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
-    if model_config …
-        logger.info(f" …
+    if model_config.model_type == "linear":
+        logger.info(f"Linear model with device: {model_config.device}")
     else:
-        logger.info(" …
+        logger.info(f"Model architecture: {model_config.num_layers} layers, {model_config.model_dim} dimensions")

     # Start training
     logger.info("Initializing training...")
src/main/trainer.py
CHANGED

@@ -69,7 +69,7 @@ class ModelTrainer:

     EXP_LR_GAMMA = 0.995

-    def __init__(self, system_config: SystemConfig, model_config: ModelConfig …
+    def __init__(self, system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments):
         """
         Initialize the ModelTrainer.

@@ -121,14 +121,10 @@ class ModelTrainer:
             Initialized model instance of the specified type
         """
         if self.args.model_name == "linear":
-            model = LinearEstimator(self.system_config, …
+            model = LinearEstimator(self.system_config, self.model_config)
         elif self.args.model_name == "adafortitran":
-            if self.model_config is None:
-                raise ValueError("model_config must be provided for AdaFortiTranEstimator.")
             model = AdaFortiTranEstimator(self.system_config, self.model_config)
         elif self.args.model_name == "fortitran":
-            if self.model_config is None:
-                raise ValueError("model_config must be provided for FortiTranEstimator.")
             model = FortiTranEstimator(self.system_config, self.model_config)
         else:
             raise ValueError(f"Unknown model name: {self.args.model_name}")

@@ -406,7 +402,7 @@ class ModelTrainer:
         self.writer.close()


-def train(system_config: SystemConfig, model_config: ModelConfig …
+def train(system_config: SystemConfig, model_config: ModelConfig, args: TrainingArguments) -> None:
     """
     Train an OFDM channel estimation model.
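End to end, the pieces now wire together roughly like this — a sketch, assuming the argument parser exposes config-path attributes named `system_config` and `model_config` (hypothetical names):

from src.main.parser import parse_arguments
from src.config import load_config
from src.main.trainer import train

args = parse_arguments()
# model_config is no longer Optional anywhere in this chain.
system_config, model_config = load_config(args.system_config, args.model_config)  # hypothetical attribute names
train(system_config, model_config, args)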
src/models/blocks/enhancers.py
CHANGED

@@ -23,9 +23,9 @@ class ConvEnhancer(nn.Module):
         """Forward pass through the convolutional enhancement network.

         Args:
-            x (torch.Tensor): Input tensor of shape (batch_size, 1, …
+            x (torch.Tensor): Input tensor of shape (batch_size, 1, num_subcarriers, num_symbols)

         Returns:
-            torch.Tensor: Enhanced tensor of shape (batch_size, 1, …
+            torch.Tensor: Enhanced tensor of shape (batch_size, 1, num_subcarriers, num_symbols)
         """
         return self.conv_block(x)
src/models/blocks/patch_processors.py
CHANGED

@@ -15,7 +15,7 @@ class PatchEmbedding(nn.Module):
         """Initialize the PatchEmbedding layer.

         Args:
-            patch_size: Size of patches to extract ( …
+            patch_size: Size of patches to extract (subcarriers_per_patch, symbols_per_patch)
         """
         super().__init__()
         self.patch_size = patch_size

@@ -25,11 +25,11 @@ class PatchEmbedding(nn.Module):
         """Transform input tensor into patch embeddings.

         Args:
-            x: Input tensor of shape (batch_size, …
+            x: Input tensor of shape (batch_size, num_subcarriers, num_symbols)

         Returns:
             Tensor of shape (batch_size, num_patches, patch_size[0]*patch_size[1])
-            where num_patches = ( …
+            where num_patches = (num_subcarriers // patch_size[0]) * (num_symbols // patch_size[1])
         """
         x = self.unfold(torch.unsqueeze(x, dim=1))
         return torch.permute(x, dims=(0, 2, 1))

@@ -46,8 +46,8 @@ class InversePatchEmbedding(nn.Module):
         """Initialize the InversePatchEmbedding layer.

         Args:
-            output_size: Size of output matrix ( …
-            patch_size: Size of input patches ( …
+            output_size: Size of output matrix (num_subcarriers, num_symbols)
+            patch_size: Size of input patches (subcarriers_per_patch, symbols_per_patch)
         """
         super().__init__()
         self.fold = nn.Fold(

@@ -64,7 +64,7 @@ class InversePatchEmbedding(nn.Module):
         where num_patches = (output_size[0] // patch_size[0]) * (output_size[1] // patch_size[1])

         Returns:
-            Tensor of shape (batch_size, …
+            Tensor of shape (batch_size, num_subcarriers, num_symbols)
         """
         x = torch.permute(x, dims=(0, 2, 1))
         x = self.fold(x)
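The shape contract these docstrings now spell out can be verified with a small standalone sketch (dimensions are illustrative; non-overlapping patches make fold an exact inverse of unfold):

import torch
import torch.nn as nn

B, S, T = 4, 120, 14          # batch, subcarriers, symbols (illustrative)
p1, p2 = 12, 2                # patch size; must divide (S, T) evenly

unfold = nn.Unfold(kernel_size=(p1, p2), stride=(p1, p2))
fold = nn.Fold(output_size=(S, T), kernel_size=(p1, p2), stride=(p1, p2))

x = torch.randn(B, S, T)
patches = unfold(x.unsqueeze(1)).permute(0, 2, 1)   # (B, num_patches, p1*p2)
assert patches.shape == (B, (S // p1) * (T // p2), p1 * p2)

y = fold(patches.permute(0, 2, 1)).squeeze(1)       # back to (B, S, T)
assert torch.allclose(x, y)                         # exact round trip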
src/models/fortitran.py
CHANGED

@@ -4,8 +4,7 @@ import logging
 from typing import Tuple, List, Optional

 from src.config.schemas import SystemConfig, ModelConfig
-from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, \
-    ChannelAdapter
+from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, ChannelAdapter


 class BaseFortiTranEstimator(nn.Module):

@@ -13,11 +12,11 @@ class BaseFortiTranEstimator(nn.Module):
     Base Hybrid CNN-Transformer Channel Estimator for OFDM Systems.

     This model performs channel estimation by:
-    1. Upsampling pilot symbols to full OFDM grid size
-    2. Applying convolutional enhancement for …
+    1. Upsampling pilot symbols to full OFDM grid size (with linear layer)
+    2. Applying convolutional enhancement for subcarrier-symbol features
     3. Converting to patch embeddings for transformer processing
     4. Using transformer encoder to capture long-range dependencies
-    5. Reconstructing …
+    5. Reconstructing subcarrier-symbol representation and applying residual connections
     6. Final convolutional refinement for high-quality channel estimates
     """

@@ -29,7 +28,7 @@ class BaseFortiTranEstimator(nn.Module):
         Args:
             system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
             model_config: Model architecture configuration (patch size, layers, etc.)
-            use_channel_adaptation: Whether to enable channel adaptation features
+            use_channel_adaptation: Whether to enable channel adaptation features (disabled for FortiTran)
         """
         super().__init__()

@@ -73,11 +72,13 @@ class BaseFortiTranEstimator(nn.Module):
             self.model_config.patch_size[0] * self.model_config.patch_size[1]
         )

-        # …
+        # Transformer input dimension (includes channel tokens if adaptation is enabled)
         if self.use_channel_adaptation:
-            …
+            if self.model_config.adaptive_token_length is None:
+                raise ValueError("adaptive_token_length must be set when channel adaptation is enabled")
+            self.transformer_input_dim = self.patch_length + self.model_config.adaptive_token_length
         else:
-            self. …
+            self.transformer_input_dim = self.patch_length

     def _build_architecture(self) -> None:
         """Construct the model architecture components."""

@@ -92,14 +93,19 @@ class BaseFortiTranEstimator(nn.Module):

         # 4. Channel adapter (conditional on use_channel_adaptation)
         if self.use_channel_adaptation:
-            …
+            if self.model_config.channel_adaptivity_hidden_sizes is None:
+                raise ValueError("channel_adaptivity_hidden_sizes must be set when channel adaptation is enabled")
+            # Convert list to tuple as expected by ChannelAdapter (exactly 3 values)
+            hidden_sizes = tuple(self.model_config.channel_adaptivity_hidden_sizes)
+            if len(hidden_sizes) != 3:
+                raise ValueError("channel_adaptivity_hidden_sizes must have exactly 3 values")
+            self.channel_adapter = ChannelAdapter(hidden_sizes)

         # 5. Transformer encoder for sequence modeling
-        transformer_input_dim = self.adaptive_patch_length if self.use_channel_adaptation else self.patch_length
         transformer_output_dim = self.patch_length  # Always output standard patch length

         self.transformer_encoder = TransformerEncoderForChannels(
-            input_dim=transformer_input_dim,
+            input_dim=self.transformer_input_dim,
             output_dim=transformer_output_dim,
             model_dim=self.model_config.model_dim,
             num_head=self.model_config.num_head,

@@ -189,7 +195,7 @@ class BaseFortiTranEstimator(nn.Module):
         """
         batch_size = x.shape[0]

-        # Flatten …
+        # Flatten subcarrier and symbol dimensions for linear upsampling
        if x.dim() > 2:
             x = x.view(batch_size, -1)

@@ -215,7 +221,7 @@ class BaseFortiTranEstimator(nn.Module):
         # Stage 5: Transformer processing for long-range dependencies
         transformer_output = self.transformer_encoder(transformer_input)

-        # Stage 6: Reconstruct …
+        # Stage 6: Reconstruct subcarrier-symbol representation
         reconstructed = self.patch_reconstructor(transformer_output)

         # Stage 7: Apply residual connection

@@ -235,7 +241,7 @@ class BaseFortiTranEstimator(nn.Module):
             'pilot_size': self.pilot_size,
             'patch_size': self.model_config.patch_size,
             'patch_length': self.patch_length,
-            ' …
+            'transformer_input_dim': self.transformer_input_dim,
             'model_dim': self.model_config.model_dim,
             'num_layers': self.model_config.num_layers,
             'device': str(self.device),
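The token-width arithmetic behind the new `transformer_input_dim` attribute, as a standalone sketch (numbers are illustrative):

# Each patch flattens to patch_size[0] * patch_size[1] features; with channel
# adaptation enabled, adaptive_token_length channel-token features are appended.
patch_size = (12, 2)          # illustrative
adaptive_token_length = 16    # illustrative

patch_length = patch_size[0] * patch_size[1]   # 24
for use_channel_adaptation in (False, True):
    dim = patch_length + adaptive_token_length if use_channel_adaptation else patch_length
    print(use_channel_adaptation, dim)         # False -> 24 (FortiTran), True -> 40 (AdaFortiTran)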
src/models/linear.py
CHANGED

@@ -10,43 +10,47 @@ import logging
 import torch
 import torch.nn as nn

-from src.config.schemas import SystemConfig
+from src.config.schemas import SystemConfig, ModelConfig


 class LinearEstimator(nn.Module):
     """Learned MMSE estimator.

+    Find W such that W * h_pilot = h_hat, where h_hat is the estimated channel;
+    W is learned by stochastic gradient descent on |h_hat - h_ideal|^2.
+
     Attributes:
         device (torch.device): Target device for computation
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        system_config (SystemConfig): Validated configuration object for OFDM system parameters
+        model_config (ModelConfig): Validated configuration object for model parameters
+        ofdm_size (Tuple[int, int]): Dimensions of OFDM frame as (num_subcarriers, num_symbols)
+            num_subcarriers (int): number of subcarriers
+            num_symbols (int): number of OFDM symbols
+        pilot_size (Tuple[int, int]): Dimensions of pilot signal as (num_subcarriers, num_symbols)
+            num_subcarriers (int): number of pilots across subcarriers
+            num_symbols (int): number of pilots across OFDM symbols
     """

-    def __init__(self, …
+    def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
         """Initialize the MMSE estimator.

         Args:
-            …
-            …
+            system_config: Validated SystemConfig object containing OFDM system parameters
+            model_config: Validated ModelConfig object containing model parameters
         """
         super().__init__()

-        self. …
-        self. …
+        self.system_config = system_config
+        self.model_config = model_config
+        self.device = torch.device(model_config.device)
         self.logger = logging.getLogger(__name__)

         # Extract dimensions from validated config
-        self.ofdm_size = ( …
-        self.pilot_size = ( …
+        self.ofdm_size = (system_config.ofdm.num_scs, system_config.ofdm.num_symbols)
+        self.pilot_size = (system_config.pilot.num_scs, system_config.pilot.num_symbols)

         # Calculate feature dimensions
-        in_feature_dim = …
-        out_feature_dim = …
+        in_feature_dim = system_config.pilot.num_scs * system_config.pilot.num_symbols
+        out_feature_dim = system_config.ofdm.num_scs * system_config.ofdm.num_symbols

         self.logger.info(f"Initializing LinearEstimator:")
         self.logger.info(f"  OFDM size: {self.ofdm_size}")

@@ -70,7 +74,7 @@ class LinearEstimator(nn.Module):
             Estimated OFDM signal tensor with shape
             (batch_size, ofdm_size[0], ofdm_size[1])
         """
-        # pytorch does …
+        # pytorch does nothing if the input is already on the correct device
        x = x.to(self.device)
         self.logger.debug(f"Input shape: {x.size()}")

@@ -95,14 +99,6 @@ class LinearEstimator(nn.Module):

         return x

-    def get_config(self) -> SystemConfig:
-        """Get the configuration used by this estimator.
-
-        Returns:
-            SystemConfig: The configuration object
-        """
-        return self.config
-
     def __repr__(self) -> str:
         """String representation of the estimator."""
         return (