|
|
import torch |
|
|
import torch.nn as nn |
|
|
from utils import patch_reconstructor |
|
|
|
|
|
|
|
|
class LosNlosClassificationHead(nn.Module):
    """Binary LoS/NLoS classification head.

    Flattens the per-sample patch embeddings and maps them through a
    small MLP (batch-norm + dropout regularized) to two class logits.

    Args:
        input_dim (tuple): (n_patches, d_model) — number of patches and
            feature dimension of each patch embedding.
    """

    def __init__(self, input_dim):
        super().__init__()
        n_patches, d_model = input_dim
        # Narrow bottleneck (8 units) before the two output logits.
        self.classifier = nn.Sequential(
            nn.Linear(n_patches * d_model, 8),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(8, 2),
        )

    def forward(self, x):
        # Collapse every non-batch dimension, then classify.
        flat = x.view(x.size(0), -1)
        return self.classifier(flat)
|
|
|
|
|
class BeamPredictionHead(nn.Module):
    """Task head for mmWave beam index prediction.

    Flattens the patch embeddings and maps them through a two-hidden-layer
    MLP to logits over 64 candidate beam indices.

    Args:
        input_dim (tuple): (n_patches, d_model) — number of patches and
            feature dimension of each patch embedding.
    """

    def __init__(self, input_dim):
        super().__init__()
        n_patches, d_model = input_dim
        # 256 -> 128 -> 64 logits; dropout only after the first hidden layer.
        self.classifier = nn.Sequential(
            nn.Linear(n_patches * d_model, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 64),
        )

    def forward(self, x):
        # Collapse every non-batch dimension, then classify.
        flat = x.view(x.size(0), -1)
        return self.classifier(flat)
|
|
|
|
|
class ChannelInterpolationHead(nn.Module):
    """Task head for reconstructing missing channel values from patch embeddings.

    A per-patch linear projection maps each d_model embedding down to a fixed
    per-patch payload, and patch_reconstructor stitches the projected patches
    back into an (n_rows x n_cols) channel matrix.

    Args:
        input_dim (tuple): (n_patches, d_model).
        output_dim (tuple): (target_channels, n_rows, n_cols) — shape of the
            output channel matrix.
    """

    # Per-patch projection width. Kept as a single class constant so that
    # __init__ and forward cannot drift apart (was a duplicated magic 32).
    PATCH_DIM = 32

    def __init__(self, input_dim, output_dim):
        super().__init__()
        n_patches, d_model = input_dim
        # NOTE(review): target_channels is unpacked but never used in this
        # class — presumably implied by patch_reconstructor or the caller;
        # confirm before removing it from output_dim.
        target_channels, self.n_rows, self.n_cols = output_dim
        # A Sequential wrapping one Linear looks redundant, but is kept so
        # state_dict keys ("fcn.0.*") stay compatible with saved checkpoints.
        self.fcn = nn.Sequential(
            nn.Linear(d_model, self.PATCH_DIM)
        )

    def forward(self, x):
        batch_size, n_patches, d_model = x.size()
        # Fold patches into the batch dimension so the projection is applied
        # independently to every patch.
        x = x.reshape(batch_size * n_patches, d_model)
        x = self.fcn(x)
        x = x.reshape(batch_size, n_patches, self.PATCH_DIM)
        # Reassemble the projected patches into the full channel matrix.
        x = patch_reconstructor(x, self.n_rows, self.n_cols)
        return x
|
|
|
|
|
class ChannelEstimationHead(nn.Module):
    """Task head for full channel estimation from embeddings.

    Structurally identical to ChannelInterpolationHead but typically used
    for denoising / noisy reconstruction: a per-patch linear projection
    followed by patch_reconstructor to rebuild the channel matrix.

    Args:
        input_dim (tuple): (n_patches, d_model).
        output_dim (tuple): (target_channels, n_rows, n_cols) — shape of the
            target full-resolution channel.
    """

    # Per-patch projection width. Kept as a single class constant so that
    # __init__ and forward cannot drift apart (was a duplicated magic 32).
    PATCH_DIM = 32

    def __init__(self, input_dim, output_dim):
        super().__init__()
        n_patches, d_model = input_dim
        # NOTE(review): target_channels is unpacked but never used in this
        # class — presumably implied by patch_reconstructor or the caller;
        # confirm before removing it from output_dim.
        target_channels, self.n_rows, self.n_cols = output_dim
        # A Sequential wrapping one Linear looks redundant, but is kept so
        # state_dict keys ("fcn.0.*") stay compatible with saved checkpoints.
        self.fcn = nn.Sequential(
            nn.Linear(d_model, self.PATCH_DIM)
        )

    def forward(self, x):
        batch_size, n_patches, d_model = x.size()
        # Fold patches into the batch dimension so the projection is applied
        # independently to every patch.
        x = x.reshape(batch_size * n_patches, d_model)
        x = self.fcn(x)
        x = x.reshape(batch_size, n_patches, self.PATCH_DIM)
        # Reassemble the projected patches into the estimated channel.
        x = patch_reconstructor(x, self.n_rows, self.n_cols)
        return x
|
|
|
|
|
class ChannelChartingHead(nn.Module):
    """Task head for 2D channel charting (e.g., learning spatial topology).

    Flattens the patch embeddings and reduces them through a small MLP
    to a pair of 2D chart coordinates per sample.

    Args:
        input_dim (tuple): (n_patches, d_model).
    """

    def __init__(self, input_dim):
        super().__init__()
        n_patches, d_model = input_dim
        # Progressive reduction 64 -> 32 -> 2 chart coordinates.
        self.fcn = nn.Sequential(
            nn.Linear(n_patches * d_model, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        # Collapse every non-batch dimension, then project to 2D.
        flat = x.view(x.size(0), -1)
        return self.fcn(flat)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-task fine-tuning configurations. Each entry is consumed as a dict with:
#   task             — name matching one of the task-head classes above.
#   optimizer_config — kwargs for the optimizer (learning rate).
#   scheduler        — step-decay parameters (step_size epochs, decay gamma).
#   epochs / batch_size / seed — training-loop settings.
#   fine_tune_layers — backbone layers unfrozen during fine-tuning: a list of
#                      layer names, or "full" for the whole backbone.
#   input_type       — which embedding the head receives ("cls_emb",
#                      "mean_pooled", or "channel_emb").
#   selected_tokens  — token subset to use; None means all tokens.
training_configs = [
    {
        "task": "LosNlosClassification",
        "optimizer_config": {"lr": 1e-3},
        "scheduler": {"step_size": 20, "gamma": 0.5},
        "epochs": 200,
        "batch_size": 128,
        "seed": 42,
        "fine_tune_layers": ["layers.9", "layers.10", "layers.11"],
        "input_type": "cls_emb",
        "selected_tokens": None
    },
    {
        "task": "BeamPrediction",
        "optimizer_config": {"lr": 1e-3},
        "scheduler": {"step_size": 20, "gamma": 0.8},
        "epochs": 70,
        "batch_size": 128,
        "seed": 42,
        # Entire backbone is unfrozen for beam prediction.
        "fine_tune_layers": "full",
        "input_type": "mean_pooled",
        "selected_tokens": None
    },
    {
        "task": "ChannelInterpolation",
        "optimizer_config": {"lr": 1e-2},
        "scheduler": {"step_size": 25, "gamma": 0.2},
        "epochs": 100,
        "batch_size": 128,
        "seed": 42,
        "fine_tune_layers": ["layers.10", "layers.11"],
        "input_type": "channel_emb",
        "selected_tokens": None
    },
    {
        "task": "ChannelEstimation",
        "optimizer_config": {"lr": 1e-2},
        "scheduler": {"step_size": 50, "gamma": 0.3},
        "epochs": 200,
        "batch_size": 128,
        "seed": 42,
        "fine_tune_layers": ["layers.9", "layers.10", "layers.11"],
        "input_type": "channel_emb",
        "selected_tokens": None
    },
    {
        "task": "ChannelCharting",
        "optimizer_config": {"lr": 1e-3},
        "scheduler": {"step_size": 40, "gamma": 0.6},
        "epochs": 150,
        "batch_size": 128,
        "seed": 42,
        "fine_tune_layers": ["layers.10", "layers.11"],
        "input_type": "mean_pooled",
        "selected_tokens": None
    }
]
|
|
|