Commit
·
71dbdc8
1
Parent(s):
687eaba
minor fixes on src/main/parser.py
Browse files- src/main.py +38 -7
- src/main/parser.py +16 -61
- src/main/train_helpers.py +9 -12
- src/main/trainer.py +15 -22
src/main.py
CHANGED
|
@@ -32,13 +32,14 @@ Dataset Requirements:
|
|
| 32 |
└── ...
|
| 33 |
|
| 34 |
Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
|
| 35 |
-
- H[:, :, 0]: Ground truth channel
|
| 36 |
-
- H[:, :, 1]: LS channel estimate with zeros for non-pilot positions
|
| 37 |
-
- H[:, :, 2]:
|
| 38 |
"""
|
| 39 |
|
| 40 |
import logging
|
| 41 |
import sys
|
|
|
|
| 42 |
from pathlib import Path
|
| 43 |
|
| 44 |
from src.main.parser import parse_arguments
|
|
@@ -47,18 +48,26 @@ from src.config import load_config
|
|
| 47 |
from src.config.schemas import ModelConfig
|
| 48 |
|
| 49 |
|
| 50 |
-
def setup_logging(log_level: str) -> None:
|
| 51 |
"""Set up logging configuration.
|
| 52 |
|
| 53 |
Args:
|
| 54 |
log_level: Logging level string (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
|
|
|
|
|
|
| 55 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
logging.basicConfig(
|
| 57 |
level=getattr(logging, log_level.upper()),
|
| 58 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 59 |
handlers=[
|
| 60 |
logging.StreamHandler(sys.stdout),
|
| 61 |
-
logging.FileHandler(
|
| 62 |
]
|
| 63 |
)
|
| 64 |
|
|
@@ -70,7 +79,7 @@ def main() -> None:
|
|
| 70 |
args = parse_arguments()
|
| 71 |
|
| 72 |
# Set up logging
|
| 73 |
-
setup_logging(args.python_log_level)
|
| 74 |
logger = logging.getLogger(__name__)
|
| 75 |
|
| 76 |
logger.info("Starting OFDM channel estimation model training")
|
|
@@ -86,13 +95,35 @@ def main() -> None:
|
|
| 86 |
args.model_config_path
|
| 87 |
)
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
logger.info("Configuration loaded successfully")
|
| 90 |
logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
|
| 91 |
logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
|
|
|
|
|
|
|
| 92 |
if model_config.model_type == "linear":
|
| 93 |
logger.info(f"Linear model with device: {model_config.device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
else:
|
| 95 |
-
logger.
|
| 96 |
|
| 97 |
# Start training
|
| 98 |
logger.info("Initializing training...")
|
|
|
|
| 32 |
└── ...
|
| 33 |
|
| 34 |
Each .mat file must contain variable 'H' with shape [subcarriers, symbols, 3]:
|
| 35 |
+
- H[:, :, 0]: Ground truth channel (complex-valued channel matrix)
|
| 36 |
+
- H[:, :, 1]: LS channel estimate with zeros for non-pilot positions (complex-valued) - used as input to models
|
| 37 |
+
- H[:, :, 2]: Bilinear interpolated LS channel estimate (complex-valued) - available but currently unused
|
| 38 |
"""
|
| 39 |
|
| 40 |
import logging
|
| 41 |
import sys
|
| 42 |
+
from datetime import datetime
|
| 43 |
from pathlib import Path
|
| 44 |
|
| 45 |
from src.main.parser import parse_arguments
|
|
|
|
| 48 |
from src.config.schemas import ModelConfig
|
| 49 |
|
| 50 |
|
| 51 |
+
def setup_logging(log_level: str, log_dir: Path, exp_id: str) -> None:
|
| 52 |
"""Set up logging configuration.
|
| 53 |
|
| 54 |
Args:
|
| 55 |
log_level: Logging level string (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
| 56 |
+
log_dir: Directory path for log files
|
| 57 |
+
exp_id: Experiment identifier for log file naming
|
| 58 |
"""
|
| 59 |
+
# Create logs directory if it doesn't exist
|
| 60 |
+
log_dir.mkdir(parents=True, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
# Create log file path using exp_id for easy matching
|
| 63 |
+
log_file = log_dir / f"training_{exp_id}.log"
|
| 64 |
+
|
| 65 |
logging.basicConfig(
|
| 66 |
level=getattr(logging, log_level.upper()),
|
| 67 |
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 68 |
handlers=[
|
| 69 |
logging.StreamHandler(sys.stdout),
|
| 70 |
+
logging.FileHandler(log_file)
|
| 71 |
]
|
| 72 |
)
|
| 73 |
|
|
|
|
| 79 |
args = parse_arguments()
|
| 80 |
|
| 81 |
# Set up logging
|
| 82 |
+
setup_logging(args.python_log_level, args.python_log_dir, args.exp_id)
|
| 83 |
logger = logging.getLogger(__name__)
|
| 84 |
|
| 85 |
logger.info("Starting OFDM channel estimation model training")
|
|
|
|
| 95 |
args.model_config_path
|
| 96 |
)
|
| 97 |
|
| 98 |
+
# Validate model type consistency
|
| 99 |
+
expected_model_types = {
|
| 100 |
+
"linear": "linear",
|
| 101 |
+
"fortitran": "fortitran",
|
| 102 |
+
"adafortitran": "adafortitran"
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
if args.model_name not in expected_model_types:
|
| 106 |
+
raise ValueError(f"Unknown model name: {args.model_name}. Expected one of: {list(expected_model_types.keys())}")
|
| 107 |
+
|
| 108 |
+
if model_config.model_type != expected_model_types[args.model_name]:
|
| 109 |
+
raise ValueError(f"Model type mismatch: config specifies '{model_config.model_type}' but model name is '{args.model_name}'")
|
| 110 |
+
|
| 111 |
logger.info("Configuration loaded successfully")
|
| 112 |
logger.info(f"OFDM dimensions: {system_config.ofdm.num_scs} subcarriers x {system_config.ofdm.num_symbols} symbols")
|
| 113 |
logger.info(f"Pilot dimensions: {system_config.pilot.num_scs} subcarriers x {system_config.pilot.num_symbols} symbols")
|
| 114 |
+
|
| 115 |
+
# Log model-specific information
|
| 116 |
if model_config.model_type == "linear":
|
| 117 |
logger.info(f"Linear model with device: {model_config.device}")
|
| 118 |
+
elif model_config.model_type == "fortitran":
|
| 119 |
+
logger.info(f"FortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
|
| 120 |
+
logger.info(f"Channel adaptation: disabled")
|
| 121 |
+
elif model_config.model_type == "adafortitran":
|
| 122 |
+
logger.info(f"AdaFortiTran model: {model_config.num_layers} layers, {model_config.model_dim} dimensions")
|
| 123 |
+
logger.info(f"Channel adaptation: enabled")
|
| 124 |
+
logger.info(f"Adaptive token length: {model_config.adaptive_token_length}")
|
| 125 |
else:
|
| 126 |
+
logger.warning(f"Unknown model type: {model_config.model_type}")
|
| 127 |
|
| 128 |
# Start training
|
| 129 |
logger.info("Initializing training...")
|
src/main/parser.py
CHANGED
|
@@ -9,18 +9,10 @@ of training runs.
|
|
| 9 |
|
| 10 |
from pathlib import Path
|
| 11 |
import argparse
|
| 12 |
-
from enum import Enum
|
| 13 |
from pydantic import BaseModel, Field, model_validator
|
| 14 |
from typing import Self
|
| 15 |
|
| 16 |
|
| 17 |
-
class LossType(Enum):
|
| 18 |
-
"""Enumeration of supported loss functions."""
|
| 19 |
-
MSE = "mse"
|
| 20 |
-
MAE = "mae"
|
| 21 |
-
HUBER = "huber"
|
| 22 |
-
|
| 23 |
-
|
| 24 |
class TrainingArguments(BaseModel):
|
| 25 |
"""Container for OFDM model training arguments.
|
| 26 |
|
|
@@ -29,7 +21,7 @@ class TrainingArguments(BaseModel):
|
|
| 29 |
|
| 30 |
Attributes:
|
| 31 |
# Model Configuration
|
| 32 |
-
model_name: Supports
|
| 33 |
system_config_path: Path to OFDM system configuration file
|
| 34 |
model_config_path: Path to model configuration file
|
| 35 |
|
|
@@ -42,17 +34,15 @@ class TrainingArguments(BaseModel):
|
|
| 42 |
exp_id: Experiment identifier string
|
| 43 |
python_log_level: Logging verbosity level
|
| 44 |
tensorboard_log_dir: Directory for tensorboard logs
|
|
|
|
| 45 |
|
| 46 |
# Training Hyperparameters
|
| 47 |
batch_size: Number of samples per batch
|
| 48 |
lr: Learning rate for optimizer
|
| 49 |
max_epoch: Maximum number of training epochs
|
| 50 |
patience: Early stopping patience in epochs
|
| 51 |
-
loss_type: Type of loss function to use
|
| 52 |
-
return_type: Type of data to return from dataset
|
| 53 |
|
| 54 |
-
#
|
| 55 |
-
cuda: CUDA device index
|
| 56 |
test_every_n: Number of epochs between test evaluations
|
| 57 |
"""
|
| 58 |
|
|
@@ -70,17 +60,15 @@ class TrainingArguments(BaseModel):
|
|
| 70 |
exp_id: str = Field(..., description="Experiment identifier for log folder naming")
|
| 71 |
python_log_level: str = Field(default="INFO", description="Logger level for python logging module")
|
| 72 |
tensorboard_log_dir: Path = Field(default=Path("runs"), description="Directory for tensorboard logs")
|
|
|
|
| 73 |
|
| 74 |
# Training Hyperparameters
|
| 75 |
batch_size: int = Field(default=64, gt=0, description="Training batch size")
|
| 76 |
lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
|
| 77 |
max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
|
| 78 |
patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
|
| 79 |
-
loss_type: LossType = Field(default=LossType.MSE, description="Loss function type")
|
| 80 |
-
return_type: str = Field(default="complex", description="Type of data to return from dataset")
|
| 81 |
|
| 82 |
-
#
|
| 83 |
-
cuda: int = Field(default=0, ge=0, description="CUDA device index (0 for single GPU)")
|
| 84 |
test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
|
| 85 |
|
| 86 |
@model_validator(mode='after')
|
|
@@ -133,8 +121,8 @@ def parse_arguments() -> TrainingArguments:
|
|
| 133 |
'--model_name',
|
| 134 |
type=str,
|
| 135 |
required=True,
|
| 136 |
-
choices=['
|
| 137 |
-
help='Model type to train (
|
| 138 |
)
|
| 139 |
required.add_argument(
|
| 140 |
'--system_config_path',
|
|
@@ -187,6 +175,12 @@ def parse_arguments() -> TrainingArguments:
|
|
| 187 |
default="runs",
|
| 188 |
help='Directory for tensorboard logs'
|
| 189 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
optional.add_argument(
|
| 191 |
'--test_every_n',
|
| 192 |
type=int,
|
|
@@ -211,55 +205,16 @@ def parse_arguments() -> TrainingArguments:
|
|
| 211 |
default=64,
|
| 212 |
help='Training batch size'
|
| 213 |
)
|
| 214 |
-
|
| 215 |
-
'--cuda',
|
| 216 |
-
type=int,
|
| 217 |
-
default=0,
|
| 218 |
-
help='CUDA device index (0 for single GPU)'
|
| 219 |
-
)
|
| 220 |
optional.add_argument(
|
| 221 |
'--lr',
|
| 222 |
type=float,
|
| 223 |
default=1e-3,
|
| 224 |
help='Initial learning rate'
|
| 225 |
)
|
| 226 |
-
optional.add_argument(
|
| 227 |
-
'--loss_type',
|
| 228 |
-
type=str,
|
| 229 |
-
default="mse",
|
| 230 |
-
choices=['mse', 'mae', 'huber'],
|
| 231 |
-
help='Loss function type'
|
| 232 |
-
)
|
| 233 |
-
optional.add_argument(
|
| 234 |
-
'--return_type',
|
| 235 |
-
type=str,
|
| 236 |
-
default="complex",
|
| 237 |
-
choices=['complex', 'real'],
|
| 238 |
-
help='Type of data to return from dataset'
|
| 239 |
-
)
|
| 240 |
|
| 241 |
-
args = parser.parse_args()
|
| 242 |
|
| 243 |
-
|
| 244 |
-
loss_type = LossType(args.loss_type)
|
| 245 |
|
| 246 |
# Create and validate TrainingArguments
|
| 247 |
-
return TrainingArguments(
|
| 248 |
-
model_name=args.model_name,
|
| 249 |
-
system_config_path=args.system_config_path,
|
| 250 |
-
model_config_path=args.model_config_path,
|
| 251 |
-
train_set=args.train_set,
|
| 252 |
-
val_set=args.val_set,
|
| 253 |
-
test_set=args.test_set,
|
| 254 |
-
exp_id=args.exp_id,
|
| 255 |
-
python_log_level=args.python_log_level,
|
| 256 |
-
tensorboard_log_dir=args.tensorboard_log_dir,
|
| 257 |
-
batch_size=args.batch_size,
|
| 258 |
-
lr=args.lr,
|
| 259 |
-
max_epoch=args.max_epoch,
|
| 260 |
-
patience=args.patience,
|
| 261 |
-
loss_type=loss_type,
|
| 262 |
-
return_type=args.return_type,
|
| 263 |
-
cuda=args.cuda,
|
| 264 |
-
test_every_n=args.test_every_n
|
| 265 |
-
)
|
|
|
|
| 9 |
|
| 10 |
from pathlib import Path
|
| 11 |
import argparse
|
|
|
|
| 12 |
from pydantic import BaseModel, Field, model_validator
|
| 13 |
from typing import Self
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class TrainingArguments(BaseModel):
|
| 17 |
"""Container for OFDM model training arguments.
|
| 18 |
|
|
|
|
| 21 |
|
| 22 |
Attributes:
|
| 23 |
# Model Configuration
|
| 24 |
+
model_name: Supports linear, adafortitran, or fortitran training
|
| 25 |
system_config_path: Path to OFDM system configuration file
|
| 26 |
model_config_path: Path to model configuration file
|
| 27 |
|
|
|
|
| 34 |
exp_id: Experiment identifier string
|
| 35 |
python_log_level: Logging verbosity level
|
| 36 |
tensorboard_log_dir: Directory for tensorboard logs
|
| 37 |
+
python_log_dir: Directory for python logging files
|
| 38 |
|
| 39 |
# Training Hyperparameters
|
| 40 |
batch_size: Number of samples per batch
|
| 41 |
lr: Learning rate for optimizer
|
| 42 |
max_epoch: Maximum number of training epochs
|
| 43 |
patience: Early stopping patience in epochs
|
|
|
|
|
|
|
| 44 |
|
| 45 |
+
# Evaluation
|
|
|
|
| 46 |
test_every_n: Number of epochs between test evaluations
|
| 47 |
"""
|
| 48 |
|
|
|
|
| 60 |
exp_id: str = Field(..., description="Experiment identifier for log folder naming")
|
| 61 |
python_log_level: str = Field(default="INFO", description="Logger level for python logging module")
|
| 62 |
tensorboard_log_dir: Path = Field(default=Path("runs"), description="Directory for tensorboard logs")
|
| 63 |
+
python_log_dir: Path = Field(default=Path("logs"), description="Directory for python logging files")
|
| 64 |
|
| 65 |
# Training Hyperparameters
|
| 66 |
batch_size: int = Field(default=64, gt=0, description="Training batch size")
|
| 67 |
lr: float = Field(default=1e-3, gt=0, description="Initial learning rate")
|
| 68 |
max_epoch: int = Field(default=10, gt=0, description="Maximum number of training epochs")
|
| 69 |
patience: int = Field(default=3, gt=0, description="Early stopping patience (epochs)")
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
# Evaluation
|
|
|
|
| 72 |
test_every_n: int = Field(default=10, gt=0, description="Test model every N epochs")
|
| 73 |
|
| 74 |
@model_validator(mode='after')
|
|
|
|
| 121 |
'--model_name',
|
| 122 |
type=str,
|
| 123 |
required=True,
|
| 124 |
+
choices=['linear', 'adafortitran', 'fortitran'],
|
| 125 |
+
help='Model type to train (linear, adafortitran, or fortitran)'
|
| 126 |
)
|
| 127 |
required.add_argument(
|
| 128 |
'--system_config_path',
|
|
|
|
| 175 |
default="runs",
|
| 176 |
help='Directory for tensorboard logs'
|
| 177 |
)
|
| 178 |
+
optional.add_argument(
|
| 179 |
+
'--python_log_dir',
|
| 180 |
+
type=Path,
|
| 181 |
+
default="logs",
|
| 182 |
+
help='Directory for python logging files'
|
| 183 |
+
)
|
| 184 |
optional.add_argument(
|
| 185 |
'--test_every_n',
|
| 186 |
type=int,
|
|
|
|
| 205 |
default=64,
|
| 206 |
help='Training batch size'
|
| 207 |
)
|
| 208 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
optional.add_argument(
|
| 210 |
'--lr',
|
| 211 |
type=float,
|
| 212 |
default=1e-3,
|
| 213 |
help='Initial learning rate'
|
| 214 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
|
|
|
| 216 |
|
| 217 |
+
args = parser.parse_args()
|
|
|
|
| 218 |
|
| 219 |
# Create and validate TrainingArguments
|
| 220 |
+
return TrainingArguments(**vars(args))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/main/train_helpers.py
CHANGED
|
@@ -120,7 +120,7 @@ def eval_model(
|
|
| 120 |
output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
|
| 121 |
val_loss += (2 * output.item() * batch[0].size(0))
|
| 122 |
|
| 123 |
-
val_loss /= len(eval_dataloader
|
| 124 |
return val_loss
|
| 125 |
|
| 126 |
|
|
@@ -206,7 +206,7 @@ def train_epoch(
|
|
| 206 |
train_loss += (2 * output.item() * batch[0].size(0))
|
| 207 |
|
| 208 |
scheduler.step()
|
| 209 |
-
train_loss /= len(train_dataloader
|
| 210 |
return train_loss
|
| 211 |
|
| 212 |
|
|
@@ -225,20 +225,17 @@ def _forward_pass(batch: BatchType, model: nn.Module) -> Tuple[ComplexTensor, Co
|
|
| 225 |
Tuple of (processed_estimated_channel, ideal_channel)
|
| 226 |
|
| 227 |
Raises:
|
| 228 |
-
ValueError: If model
|
| 229 |
"""
|
| 230 |
estimated_channel, ideal_channel, meta_data = batch
|
| 231 |
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
estimated_channel =
|
| 236 |
-
elif model.name == "adafortitran":
|
| 237 |
-
h_est_re = model(torch.real(estimated_channel), meta_data)
|
| 238 |
-
h_est_im = model(torch.imag(estimated_channel), meta_data)
|
| 239 |
-
estimated_channel = torch.complex(h_est_re, h_est_im)
|
| 240 |
else:
|
| 241 |
-
|
|
|
|
| 242 |
|
| 243 |
return estimated_channel, ideal_channel.to(model.device)
|
| 244 |
|
|
|
|
| 120 |
output = _compute_loss(estimated_channel, ideal_channel, loss_fn)
|
| 121 |
val_loss += (2 * output.item() * batch[0].size(0))
|
| 122 |
|
| 123 |
+
val_loss /= sum(len(batch[0]) for batch in eval_dataloader)
|
| 124 |
return val_loss
|
| 125 |
|
| 126 |
|
|
|
|
| 206 |
train_loss += (2 * output.item() * batch[0].size(0))
|
| 207 |
|
| 208 |
scheduler.step()
|
| 209 |
+
train_loss /= sum(len(batch[0]) for batch in train_dataloader)
|
| 210 |
return train_loss
|
| 211 |
|
| 212 |
|
|
|
|
| 225 |
Tuple of (processed_estimated_channel, ideal_channel)
|
| 226 |
|
| 227 |
Raises:
|
| 228 |
+
ValueError: If model type is not recognized
|
| 229 |
"""
|
| 230 |
estimated_channel, ideal_channel, meta_data = batch
|
| 231 |
|
| 232 |
+
# All models now handle complex input directly
|
| 233 |
+
if hasattr(model, 'use_channel_adaptation') and model.use_channel_adaptation:
|
| 234 |
+
# AdaFortiTran uses meta_data for channel adaptation
|
| 235 |
+
estimated_channel = model(estimated_channel, meta_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
else:
|
| 237 |
+
# Linear and FortiTran models don't use meta_data
|
| 238 |
+
estimated_channel = model(estimated_channel)
|
| 239 |
|
| 240 |
return estimated_channel, ideal_channel.to(model.device)
|
| 241 |
|
src/main/trainer.py
CHANGED
|
@@ -81,7 +81,7 @@ class ModelTrainer:
|
|
| 81 |
self.system_config = system_config
|
| 82 |
self.model_config = model_config
|
| 83 |
self.args = args
|
| 84 |
-
self.device = torch.device(
|
| 85 |
self.writer = self._setup_tensorboard()
|
| 86 |
self.logger = logging.getLogger(__name__)
|
| 87 |
|
|
@@ -120,14 +120,12 @@ class ModelTrainer:
|
|
| 120 |
Returns:
|
| 121 |
Initialized model instance of the specified type
|
| 122 |
"""
|
| 123 |
-
if self.args.model_name
|
| 124 |
-
model
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
else:
|
| 130 |
-
raise ValueError(f"Unknown model name: {self.args.model_name}")
|
| 131 |
num_params, model_summary = get_model_details(model)
|
| 132 |
self.logger.info("\n" + model_summary)
|
| 133 |
self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
|
|
@@ -280,20 +278,15 @@ class ModelTrainer:
|
|
| 280 |
|
| 281 |
def _forward_pass(self, batch, model):
|
| 282 |
estimated_channel, ideal_channel, meta_data = batch
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
h_est_re = model(torch.real(estimated_channel), meta_data)
|
| 289 |
-
h_est_im = model(torch.imag(estimated_channel), meta_data)
|
| 290 |
-
estimated_channel = torch.complex(h_est_re, h_est_im)
|
| 291 |
-
elif isinstance(model, LinearEstimator):
|
| 292 |
-
h_est_re = model(torch.real(estimated_channel))
|
| 293 |
-
h_est_im = model(torch.imag(estimated_channel))
|
| 294 |
-
estimated_channel = torch.complex(h_est_re, h_est_im)
|
| 295 |
else:
|
| 296 |
-
|
|
|
|
|
|
|
| 297 |
return estimated_channel, ideal_channel.to(model.device)
|
| 298 |
|
| 299 |
def _train_epoch(self):
|
|
|
|
| 81 |
self.system_config = system_config
|
| 82 |
self.model_config = model_config
|
| 83 |
self.args = args
|
| 84 |
+
self.device = torch.device(model_config.device)
|
| 85 |
self.writer = self._setup_tensorboard()
|
| 86 |
self.logger = logging.getLogger(__name__)
|
| 87 |
|
|
|
|
| 120 |
Returns:
|
| 121 |
Initialized model instance of the specified type
|
| 122 |
"""
|
| 123 |
+
if self.args.model_name not in self.MODEL_REGISTRY:
|
| 124 |
+
raise ValueError(f"Unknown model name: {self.args.model_name}. Available: {list(self.MODEL_REGISTRY.keys())}")
|
| 125 |
+
|
| 126 |
+
model_class = self.MODEL_REGISTRY[self.args.model_name]
|
| 127 |
+
model = model_class(self.system_config, self.model_config)
|
| 128 |
+
|
|
|
|
|
|
|
| 129 |
num_params, model_summary = get_model_details(model)
|
| 130 |
self.logger.info("\n" + model_summary)
|
| 131 |
self.logger.info(f"Model name: {self.args.model_name} | Number of parameters: {num_params}")
|
|
|
|
| 278 |
|
| 279 |
def _forward_pass(self, batch, model):
|
| 280 |
estimated_channel, ideal_channel, meta_data = batch
|
| 281 |
+
|
| 282 |
+
# All models now handle complex input directly
|
| 283 |
+
if isinstance(model, AdaFortiTranEstimator):
|
| 284 |
+
# AdaFortiTran uses meta_data for channel adaptation
|
| 285 |
+
estimated_channel = model(estimated_channel, meta_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
else:
|
| 287 |
+
# Linear and FortiTran models don't use meta_data
|
| 288 |
+
estimated_channel = model(estimated_channel)
|
| 289 |
+
|
| 290 |
return estimated_channel, ideal_channel.to(model.device)
|
| 291 |
|
| 292 |
def _train_epoch(self):
|