BerkIGuler committed on
Commit
facb3d5
·
1 Parent(s): cbe30e6

added fortitran estimator

Browse files
config/model_config.yaml CHANGED
@@ -1,3 +1,9 @@
1
- patch_size: [10, 4]
2
  num_layers: 6
3
  device: "cpu"
 
 
 
 
 
 
 
1
+ patch_size: [3, 2]
2
  num_layers: 6
3
  device: "cpu"
4
+ model_dim: 128
5
+ num_head: 4
6
+ activation: 'gelu'
7
+ dropout: 0.1
8
+ max_seq_len: 512
9
+ pos_encoding_type: 'learnable'
src/config/schemas.py CHANGED
@@ -1,5 +1,6 @@
1
  from pydantic import BaseModel, Field, model_validator
2
  from typing import Self, Tuple
 
3
 
4
 
5
  class OFDMParams(BaseModel):
@@ -19,10 +20,63 @@ class ModelParams(BaseModel):
19
 
20
  @model_validator(mode='after')
21
  def validate_device(self) -> Self:
22
- pass
23
-
24
-
25
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
 
28
  class SystemConfig(BaseModel):
@@ -84,4 +138,4 @@ class ModelConfig(BaseModel):
84
 
85
  return self
86
 
87
- model_config = {"extra": "forbid"}
 
1
  from pydantic import BaseModel, Field, model_validator
2
  from typing import Self, Tuple
3
+ import torch
4
 
5
 
6
  class OFDMParams(BaseModel):
 
20
 
21
  @model_validator(mode='after')
22
  def validate_device(self) -> Self:
23
+ """Validate that the specified device is available."""
24
+ device_str = self.device.lower()
25
+
26
+ # Handle 'auto' case - automatically select best available device
27
+ if device_str == 'auto':
28
+ if torch.cuda.is_available():
29
+ self.device = 'cuda'
30
+ elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
31
+ self.device = 'mps' # Apple Silicon
32
+ else:
33
+ self.device = 'cpu'
34
+ return self
35
+
36
+ # Validate CPU
37
+ if device_str == 'cpu':
38
+ return self
39
+
40
+ # Validate CUDA devices
41
+ if device_str.startswith('cuda'):
42
+ if not torch.cuda.is_available():
43
+ raise ValueError("CUDA is not available on this system")
44
+
45
+ # Handle specific CUDA device (e.g., 'cuda:0', 'cuda:1')
46
+ if ':' in device_str:
47
+ try:
48
+ device_id = int(device_str.split(':')[1])
49
+ if device_id >= torch.cuda.device_count():
50
+ available_devices = list(range(torch.cuda.device_count()))
51
+ raise ValueError(
52
+ f"CUDA device {device_id} not available. "
53
+ f"Available CUDA devices: {available_devices}"
54
+ )
55
+ except (ValueError, IndexError) as e:
56
+ if "invalid literal" in str(e):
57
+ raise ValueError(f"Invalid CUDA device format: {device_str}")
58
+ raise
59
+
60
+ return self
61
+
62
+ # Validate MPS (Apple Silicon)
63
+ if device_str == 'mps':
64
+ if not (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()):
65
+ raise ValueError("MPS is not available on this system")
66
+ return self
67
+
68
+ # If we get here, the device is not recognized
69
+ available_devices = ['cpu']
70
+ if torch.cuda.is_available():
71
+ cuda_devices = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
72
+ available_devices.extend(['cuda'] + cuda_devices)
73
+ if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
74
+ available_devices.append('mps')
75
+
76
+ raise ValueError(
77
+ f"Unsupported device: '{self.device}'. "
78
+ f"Available devices: {available_devices}"
79
+ )
80
 
81
 
82
  class SystemConfig(BaseModel):
 
138
 
139
  return self
140
 
141
+ model_config = {"extra": "forbid"}
src/models/blocks/__init__.py CHANGED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
"""Building blocks used by the channel-estimation models.

Explicit relative imports are used so that the package resolves its own
submodules when imported as ``src.models.blocks``; a bare
``from channel_adaptivity import ...`` is treated as an absolute import in
Python 3 and only works if this directory itself is on ``sys.path``.
"""

from .channel_adaptivity import ChannelAdapter
from .encoders import TransformerEncoderForChannels
from .enhancers import ConvEnhancer
from .patch_processors import PatchEmbedding, InversePatchEmbedding
from .positional_encodings import SinusoidalPositionalEncoding, LearnablePositionalEncoding

__all__ = [
    "ChannelAdapter",
    "TransformerEncoderForChannels",
    "ConvEnhancer",
    "PatchEmbedding",
    "InversePatchEmbedding",
    "SinusoidalPositionalEncoding",
    "LearnablePositionalEncoding",
]
src/models/blocks/patch_processors.py CHANGED
@@ -40,8 +40,8 @@ class InversePatchEmbedding(nn.Module):
40
 
41
  def __init__(
42
  self,
43
- output_size: Tuple[int, int] = (120, 28),
44
- patch_size: Tuple[int, int] = (10, 4)
45
  ):
46
  """Initialize the InversePatchEmbedding layer.
47
 
 
40
 
41
  def __init__(
42
  self,
43
+ output_size: Tuple[int, int] = (120, 14),
44
+ patch_size: Tuple[int, int] = (3, 2)
45
  ):
46
  """Initialize the InversePatchEmbedding layer.
47
 
src/models/fortitran.py CHANGED
@@ -1,40 +1,195 @@
1
- from torch import nn
2
  import torch
 
3
  import logging
4
 
5
  from src.config.schemas import SystemConfig, ModelConfig
 
6
 
7
 
8
  class FortiTranEstimator(nn.Module):
9
- """A DL-based Channel Estimator based on a hybrid convolutional + transformers model"""
 
 
 
 
 
 
 
 
 
 
 
10
  def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
11
- """Initialize the FortiTranEstimator.
 
12
 
13
  Args:
14
- system_config: SystemConfig object containing OFDM system parameters
15
- system_config: ModelConfig object containing model parameters
16
  """
17
  super().__init__()
18
 
19
  self.system_config = system_config
20
- self.device = torch.device(config.device)
21
- self.logger = logging.getLogger(__name__)
22
-
23
- # Extract dimensions from validated config
24
- self.ofdm_size = (config.ofdm.num_scs, config.ofdm.num_symbols)
25
- self.pilot_size = (config.pilot.num_scs, config.pilot.num_symbols)
26
-
27
- # Calculate feature dimensions
28
- in_feature_dim = config.pilot.num_scs * config.pilot.num_symbols
29
- out_feature_dim = config.ofdm.num_scs * config.ofdm.num_symbols
30
-
31
- self.logger.info(f"Initializing LinearEstimator:")
32
- self.logger.info(f" OFDM size: {self.ofdm_size}")
33
- self.logger.info(f" Pilot size: {self.pilot_size}")
34
- self.logger.info(f" Input features: {in_feature_dim}")
35
- self.logger.info(f" Output features: {out_feature_dim}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  self.logger.info(f" Device: {self.device}")
37
 
38
- # Create linear layer
39
- self.linear = nn.Linear(in_feature_dim, out_feature_dim)
40
- self.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from torch import nn
3
  import logging
4
 
5
  from src.config.schemas import SystemConfig, ModelConfig
6
+ from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels
7
 
8
 
9
  class FortiTranEstimator(nn.Module):
10
+ """
11
+ Hybrid CNN-Transformer Channel Estimator for OFDM Systems.
12
+
13
+ This model performs channel estimation by:
14
+ 1. Upsampling pilot symbols to full OFDM grid size
15
+ 2. Applying convolutional enhancement for spatial features
16
+ 3. Converting to patch embeddings for transformer processing
17
+ 4. Using transformer encoder to capture long-range dependencies
18
+ 5. Reconstructing spatial representation and applying residual connections
19
+ 6. Final convolutional refinement for high-quality channel estimates
20
+ """
21
+
22
  def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
23
+ """
24
+ Initialize the FortiTranEstimator.
25
 
26
  Args:
27
+ system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
28
+ model_config: Model architecture configuration (patch size, layers, etc.)
29
  """
30
  super().__init__()
31
 
32
  self.system_config = system_config
33
+ self.model_config = model_config
34
+ self.device = torch.device(model_config.device)
35
+ self.logger = logging.getLogger(self.__class__.__name__)
36
+
37
+ # Cache key dimensions for efficiency
38
+ self._setup_dimensions()
39
+
40
+ # Initialize model components
41
+ self._build_architecture()
42
+
43
+ # Move model to specified device
44
+ self.to(self.device)
45
+
46
+ self._log_initialization_info()
47
+
48
+ def _setup_dimensions(self) -> None:
49
+ """Calculate and cache key dimensions from configuration."""
50
+ # OFDM grid dimensions
51
+ self.ofdm_size = (
52
+ self.system_config.ofdm.num_scs,
53
+ self.system_config.ofdm.num_symbols
54
+ )
55
+
56
+ # Pilot arrangement dimensions
57
+ self.pilot_size = (
58
+ self.system_config.pilot.num_scs,
59
+ self.system_config.pilot.num_symbols
60
+ )
61
+
62
+ # Feature dimensions for linear layers
63
+ self.pilot_features = self.pilot_size[0] * self.pilot_size[1]
64
+ self.ofdm_features = self.ofdm_size[0] * self.ofdm_size[1]
65
+
66
+ # Patch processing dimensions
67
+ self.patch_length = (
68
+ self.model_config.patch_size[0] * self.model_config.patch_size[1]
69
+ )
70
+
71
+ def _build_architecture(self) -> None:
72
+ """Construct the model architecture components."""
73
+ # 1. Pilot-to-OFDM upsampling
74
+ self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)
75
+ # 2. Initial convolutional enhancement
76
+ self.initial_enhancer = ConvEnhancer()
77
+
78
+ # 3. Patch embedding for transformer processing
79
+ self.patch_embedder = PatchEmbedding(self.model_config.patch_size)
80
+
81
+ # 4. Transformer encoder for sequence modeling
82
+ self.transformer_encoder = TransformerEncoderForChannels(
83
+ input_dim=self.patch_length,
84
+ output_dim=self.patch_length,
85
+ model_dim=self.model_config.model_dim,
86
+ num_head=self.model_config.num_head,
87
+ activation=self.model_config.activation,
88
+ dropout=self.model_config.dropout,
89
+ num_layers=self.model_config.num_layers,
90
+ max_len=self.model_config.max_seq_len,
91
+ pos_encoding_type=self.model_config.pos_encoding_type,
92
+ )
93
+
94
+ # 5. Patch reconstruction
95
+ self.patch_reconstructor = InversePatchEmbedding(
96
+ self.ofdm_size,
97
+ self.model_config.patch_size
98
+ )
99
+
100
+ # 6. Final convolutional refinement
101
+ self.final_refiner = ConvEnhancer()
102
+
103
+ def _log_initialization_info(self) -> None:
104
+ """Log model initialization details."""
105
+ self.logger.info("FortiTranEstimator initialized successfully:")
106
+ self.logger.info(f" OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
107
+ self.logger.info(f" Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
108
+ self.logger.info(f" Patch size: {self.model_config.patch_size}")
109
+ self.logger.info(f" Model dimension: {self.model_config.model_dim}")
110
+ self.logger.info(f" Transformer layers: {self.model_config.num_layers}")
111
  self.logger.info(f" Device: {self.device}")
112
 
113
+ total_params = sum(p.numel() for p in self.parameters())
114
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
115
+ self.logger.info(f" Total parameters: {total_params:,}")
116
+ self.logger.info(f" Trainable parameters: {trainable_params:,}")
117
+
118
+ def forward(self, pilot_symbols: torch.Tensor) -> torch.Tensor:
119
+ """
120
+ Forward pass for channel estimation.
121
+
122
+ Args:
123
+ pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
124
+
125
+ Returns:
126
+ Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]
127
+ """
128
+ # Ensure input is on correct device
129
+ pilot_symbols = pilot_symbols.to(self.device)
130
+
131
+ # Process real and imaginary parts separately
132
+ real_estimate = self._forward_real_valued(pilot_symbols.real)
133
+ imag_estimate = self._forward_real_valued(pilot_symbols.imag)
134
+
135
+ # Combine into complex tensor
136
+ channel_estimate = torch.complex(real_estimate, imag_estimate)
137
+
138
+ return channel_estimate
139
+
140
+ def _forward_real_valued(self, x: torch.Tensor) -> torch.Tensor:
141
+ """
142
+ Process real-valued input through the estimation pipeline.
143
+
144
+ Args:
145
+ x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
146
+
147
+ Returns:
148
+ Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
149
+ """
150
+ batch_size = x.shape[0]
151
+
152
+ # Flatten spatial dimensions for linear upsampling
153
+ if x.dim() > 2:
154
+ x = x.view(batch_size, -1)
155
+
156
+ # Stage 1: Upsample from pilot grid to OFDM grid
157
+ upsampled = self.pilot_upsampler(x)
158
+
159
+ # Reshape for convolutional processing
160
+ upsampled_2d = upsampled.view(batch_size, 1, *self.ofdm_size)
161
+
162
+ # Stage 2: Initial convolutional enhancement
163
+ conv_enhanced = torch.squeeze(self.initial_enhancer(upsampled_2d), dim=1)
164
+
165
+ # Stage 3: Convert to patch embeddings
166
+ patch_embeddings = self.patch_embedder(conv_enhanced)
167
+
168
+ # Stage 4: Transformer processing for long-range dependencies
169
+ transformer_output = self.transformer_encoder(patch_embeddings)
170
+
171
+ # Stage 5: Reconstruct spatial representation
172
+ reconstructed = self.patch_reconstructor(transformer_output)
173
+
174
+ # Stage 6: Apply residual connection
175
+ residual_combined = conv_enhanced + reconstructed
176
+
177
+ # Stage 7: Final convolutional refinement
178
+ refined_output = torch.squeeze(self.final_refiner(torch.unsqueeze(residual_combined, dim=1)), dim=1)
179
+
180
+ return refined_output
181
+
182
+ def get_model_info(self) -> dict:
183
+ """Return model configuration and statistics."""
184
+ return {
185
+ 'model_name': self.__class__.__name__,
186
+ 'ofdm_size': self.ofdm_size,
187
+ 'pilot_size': self.pilot_size,
188
+ 'patch_size': self.model_config.patch_size,
189
+ 'patch_length': self.patch_length,
190
+ 'model_dim': self.model_config.model_dim,
191
+ 'num_layers': self.model_config.num_layers,
192
+ 'device': str(self.device),
193
+ 'total_parameters': sum(p.numel() for p in self.parameters()),
194
+ 'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
195
+ }