BerkIGuler committed on
Commit b57f75f · 1 Parent(s): facb3d5

added adafortitran estimator

config/model_config.yaml CHANGED
@@ -7,3 +7,5 @@ activation: 'gelu'
  dropout: 0.1
  max_seq_len: 512
  pos_encoding_type: 'learnable'
+ channel_adaptivity_hidden_sizes: [7, 42, 560]
+ adaptive_token_length: 6
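The two new keys drive the channel-adaptive path introduced in src/models/adafortitran.py below: channel_adaptivity_hidden_sizes sizes the ChannelAdapter that encodes SNR, delay spread, and Doppler shift, while adaptive_token_length is the number of conditioning features appended to each patch token before the transformer. A minimal sketch of the resulting dimension bookkeeping, assuming a hypothetical patch_size of [4, 4] (patch_size itself is not part of this diff):

# Sketch only; patch_size = (4, 4) is an assumed example value.
patch_size = (4, 4)
adaptive_token_length = 6                                        # from config/model_config.yaml above
patch_length = patch_size[0] * patch_size[1]                     # 16 features per patch
adaptive_patch_length = patch_length + adaptive_token_length     # 22 -> transformer input_dim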
src/models/adafortitran.py ADDED
@@ -0,0 +1,216 @@
+ import torch
+ from torch import nn
+ import logging
+ from typing import Tuple, List
+
+ from src.config.schemas import SystemConfig, ModelConfig
+ from src.models.blocks import ConvEnhancer, PatchEmbedding, InversePatchEmbedding, TransformerEncoderForChannels, ChannelAdapter
+
+
+ class AdaFortiTranEstimator(nn.Module):
+
+     """
+     Hybrid CNN-Transformer Channel Estimator for OFDM Systems with channel adaptation.
+
+     This model performs channel estimation by:
+     1. Upsampling pilot symbols to full OFDM grid size
+     2. Applying convolutional enhancement for spatial features
+     3. Converting to patch embeddings for transformer processing
+     4. Concatenating channel statistics priors to channel patches
+     5. Using transformer encoder to capture long-range dependencies
+     6. Reconstructing spatial representation and applying residual connections
+     7. Final convolutional refinement for high-quality channel estimates
+     """
+
+     def __init__(self, system_config: SystemConfig, model_config: ModelConfig) -> None:
+         """
+         Initialize the AdaFortiTranEstimator.
+
+         Args:
+             system_config: OFDM system configuration (subcarriers, symbols, pilot arrangement)
+             model_config: Model architecture configuration (patch size, layers, etc.)
+         """
+         super().__init__()
+
+         self.system_config = system_config
+         self.model_config = model_config
+         self.device = torch.device(model_config.device)
+         self.logger = logging.getLogger(self.__class__.__name__)
+
+         # Cache key dimensions for efficiency
+         self._setup_dimensions()
+
+         # Initialize model components
+         self._build_architecture()
+
+         # Move model to specified device
+         self.to(self.device)
+
+         self._log_initialization_info()
+
+     def _setup_dimensions(self) -> None:
+         """Calculate and cache key dimensions from configuration."""
+         # OFDM grid dimensions
+         self.ofdm_size = (
+             self.system_config.ofdm.num_scs,
+             self.system_config.ofdm.num_symbols
+         )
+
+         # Pilot arrangement dimensions
+         self.pilot_size = (
+             self.system_config.pilot.num_scs,
+             self.system_config.pilot.num_symbols
+         )
+
+         # Feature dimensions for linear layers
+         self.pilot_features = self.pilot_size[0] * self.pilot_size[1]
+         self.ofdm_features = self.ofdm_size[0] * self.ofdm_size[1]
+
+         # Patch processing dimensions
+         self.patch_length = (
+             self.model_config.patch_size[0] * self.model_config.patch_size[1]
+         )
+
+         self.adaptive_patch_length = self.patch_length + self.model_config.adaptive_token_length
+
+     def _build_architecture(self) -> None:
+         """Construct the model architecture components."""
+         # 1. Pilot-to-OFDM upsampling
+         self.pilot_upsampler = nn.Linear(self.pilot_features, self.ofdm_features)
+         # 2. Initial convolutional enhancement
+         self.initial_enhancer = ConvEnhancer()
+
+         # 3. Patch embedding for transformer processing
+         self.patch_embedder = PatchEmbedding(self.model_config.patch_size)
+
+         # 4. Channel adapter for conditional attention
+         self.channel_adapter = ChannelAdapter(self.model_config.channel_adaptivity_hidden_sizes)
+
+         # 5. Transformer encoder for sequence modeling
+         self.transformer_encoder = TransformerEncoderForChannels(
+             input_dim=self.adaptive_patch_length,
+             output_dim=self.patch_length,
+             model_dim=self.model_config.model_dim,
+             num_head=self.model_config.num_head,
+             activation=self.model_config.activation,
+             dropout=self.model_config.dropout,
+             num_layers=self.model_config.num_layers,
+             max_len=self.model_config.max_seq_len,
+             pos_encoding_type=self.model_config.pos_encoding_type
+         )
+
+         # 6. Patch reconstruction
+         self.patch_reconstructor = InversePatchEmbedding(
+             self.ofdm_size,
+             self.model_config.patch_size
+         )
+
+         # 7. Final convolutional refinement
+         self.final_refiner = ConvEnhancer()
+
+     def _log_initialization_info(self) -> None:
+         """Log model initialization details."""
+         self.logger.info("AdaFortiTranEstimator initialized successfully:")
+         self.logger.info(f"  OFDM grid: {self.ofdm_size[0]}×{self.ofdm_size[1]} = {self.ofdm_features} elements")
+         self.logger.info(f"  Pilot grid: {self.pilot_size[0]}×{self.pilot_size[1]} = {self.pilot_features} elements")
+         self.logger.info(f"  Patch size: {self.model_config.patch_size}")
+         self.logger.info(f"  Model dimension: {self.model_config.model_dim}")
+         self.logger.info(f"  Transformer layers: {self.model_config.num_layers}")
+         self.logger.info(f"  Device: {self.device}")
+
+         total_params = sum(p.numel() for p in self.parameters())
+         trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         self.logger.info(f"  Total parameters: {total_params:,}")
+         self.logger.info(f"  Trainable parameters: {trainable_params:,}")
+
+     def forward(self, pilot_symbols: torch.Tensor, meta_data: Tuple) -> torch.Tensor:
+         """
+         Forward pass for channel estimation.
+
+         Args:
+             pilot_symbols: Complex pilot symbols of shape [batch, pilot_scs, pilot_symbols]
+             meta_data: Per-sample metadata tuple; elements 1-3 (SNR, delay spread,
+                 max Doppler shift) are used as channel conditions. TODO: Add complete type annotation.
+
+         Returns:
+             Estimated channel matrix of shape [batch, ofdm_scs, ofdm_symbols]
+         """
+         # Extract and move channel conditions to device
+         _, snr, delay_spread, max_dop_shift, _, _ = meta_data
+         channel_conditions = [
+             tensor.to(self.device)
+             for tensor in (snr, delay_spread, max_dop_shift)
+         ]
+
+         # Ensure input is on correct device
+         pilot_symbols = pilot_symbols.to(self.device)
+
+         # Process real and imaginary parts separately
+         real_estimate = self._forward_real_valued(pilot_symbols.real, channel_conditions)
+         imag_estimate = self._forward_real_valued(pilot_symbols.imag, channel_conditions)
+
+         # Combine into complex tensor
+         channel_estimate = torch.complex(real_estimate, imag_estimate)
+
+         return channel_estimate
+
+     def _forward_real_valued(self, x: torch.Tensor, channel_conditions: List[torch.Tensor]) -> torch.Tensor:
+         """
+         Process real-valued input through the estimation pipeline.
+
+         Args:
+             x: Real-valued input tensor [batch, pilot_features] or [batch, pilot_scs, pilot_symbols]
+             channel_conditions: Channel condition tensors (SNR, delay spread, max Doppler shift)
+
+         Returns:
+             Real-valued channel estimate [batch, ofdm_scs, ofdm_symbols]
+         """
+         batch_size = x.shape[0]
+
+         # Flatten spatial dimensions for linear upsampling
+         if x.dim() > 2:
+             x = x.view(batch_size, -1)
+
+         # Stage 1: Upsample from pilot grid to OFDM grid
+         upsampled = self.pilot_upsampler(x)
+
+         # Reshape for convolutional processing
+         upsampled_2d = upsampled.view(batch_size, 1, *self.ofdm_size)
+
+         # Stage 2: Initial convolutional enhancement
+         conv_enhanced = torch.squeeze(self.initial_enhancer(upsampled_2d), dim=1)
+
+         # Stage 3: Convert to patch embeddings
+         patch_embeddings = self.patch_embedder(conv_enhanced)
+
+         # Stage 4: Get conditioned channel encodings
+         encoded_channel_condition = self.channel_adapter(*channel_conditions)
+         conditioned_channel_encodings = torch.cat((patch_embeddings, encoded_channel_condition), dim=2)
+
+         # Stage 5: Transformer processing for long-range dependencies
+         transformer_output = self.transformer_encoder(conditioned_channel_encodings)
+
+         # Stage 6: Reconstruct spatial representation
+         reconstructed = self.patch_reconstructor(transformer_output)
+
+         # Stage 7: Apply residual connection
+         residual_combined = conv_enhanced + reconstructed
+
+         # Stage 8: Final convolutional refinement
+         refined_output = torch.squeeze(self.final_refiner(torch.unsqueeze(residual_combined, dim=1)), dim=1)
+
+         return refined_output
+
+     def get_model_info(self) -> dict:
+         """Return model configuration and statistics."""
+         return {
+             'model_name': self.__class__.__name__,
+             'ofdm_size': self.ofdm_size,
+             'pilot_size': self.pilot_size,
+             'patch_size': self.model_config.patch_size,
+             'patch_length': self.patch_length,
+             'model_dim': self.model_config.model_dim,
+             'num_layers': self.model_config.num_layers,
+             'device': str(self.device),
+             'total_parameters': sum(p.numel() for p in self.parameters()),
+             'trainable_parameters': sum(p.numel() for p in self.parameters() if p.requires_grad)
+         }
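A minimal usage sketch of the new estimator. The config loading, pilot/OFDM grid sizes, batch size, and metadata tensor shapes below are illustrative assumptions, not values taken from this commit; only the class, its import path, and the 6-tuple meta_data layout come from the code above.

# Illustrative sketch; grid sizes, batch size, and metadata shapes are assumptions.
import torch
from src.models.adafortitran import AdaFortiTranEstimator

# system_config / model_config: instances of SystemConfig / ModelConfig,
# e.g. parsed from the YAML files under config/ (loading code not shown).
model = AdaFortiTranEstimator(system_config, model_config)

# Complex pilot observations: [batch, pilot_scs, pilot_symbols] (example: 8 x 18 x 2)
pilot_symbols = torch.randn(8, 18, 2, dtype=torch.cfloat)

# meta_data is a 6-tuple; only entries 1-3 (SNR, delay spread, max Doppler shift) are used.
snr = torch.full((8, 1), 20.0)
delay_spread = torch.full((8, 1), 100e-9)
max_dop_shift = torch.full((8, 1), 50.0)
meta_data = (None, snr, delay_spread, max_dop_shift, None, None)

channel_estimate = model(pilot_symbols, meta_data)  # complex tensor: [batch, ofdm_scs, ofdm_symbols]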