"""
Training script for the modulation classifier.

Expects data in DATA_DIR: either synthetic or preprocessed RadioML-style (see README).
NOTE: only the synthetic loader is implemented in this script; DATA_DIR is
currently used only as the output directory for the trained checkpoint.
"""

import argparse
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from model import ModulationClassifier


def get_synthetic_dataloaders(data_dir: str, seq_len: int = 128, num_classes: int = 6, batch_size: int = 64):
    """Build train/val loaders over random synthetic I/Q-shaped data.

    Intended only to smoke-test the training pipeline. Inputs are standard-normal
    noise of shape (N, 2, seq_len). Labels are drawn uniformly at random,
    independently of the inputs, so there is no learnable signal: expect
    validation accuracy near 1/num_classes.

    Args:
        data_dir: Currently unused; no real dataset is loaded from it.
        seq_len: Number of samples per channel in each example.
        num_classes: Labels are drawn from ``range(num_classes)``.
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, val_loader)`` over 2000 and 400 examples respectively.
        Only the training loader shuffles.
    """
    n_train, n_val = 2000, 400
    # NOTE: this seeds torch's *global* RNG. It therefore also fixes everything
    # drawn afterwards in the process, e.g. DataLoader shuffling and, presumably,
    # the model's weight init in main().
    torch.manual_seed(42)
    # Random I/Q-like data (draw order matters for reproducibility).
    X_train = torch.randn(n_train, 2, seq_len)
    y_train = torch.randint(0, num_classes, (n_train,))
    X_val = torch.randn(n_val, 2, seq_len)
    y_val = torch.randint(0, num_classes, (n_val,))
    train_ds = TensorDataset(X_train, y_train)
    val_ds = TensorDataset(X_val, y_val)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, val_loader


def _train_one_epoch(model, loader, opt, criterion, device) -> float:
    """Run one optimisation pass over ``loader``; return the per-sample mean loss.

    Returns NaN if ``loader`` yields no samples.
    """
    model.train()
    total_loss, n_seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()
        # Assumes criterion uses reduction="mean" (the nn.CrossEntropyLoss default
        # used in main). Re-weight by batch size so the smaller final batch does
        # not skew the epoch average.
        batch_n = y.size(0)
        total_loss += loss.item() * batch_n
        n_seen += batch_n
    return total_loss / n_seen if n_seen else float("nan")


def _evaluate(model, loader, device) -> float:
    """Return top-1 accuracy of ``model`` over ``loader`` (NaN if it is empty)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            total += y.size(0)
            correct += (pred == y).sum().item()
    return correct / total if total else float("nan")


def main():
    """Parse CLI args, train the classifier on synthetic data, and save its weights.

    The state dict is written to ``<data_dir>/modulation_classifier.pt``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="./data", help="Path to data directory")
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--num_classes", type=int, default=6)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # Loaders are built before the model so the global seed set inside
    # get_synthetic_dataloaders is in effect when the model is constructed.
    train_loader, val_loader = get_synthetic_dataloaders(
        args.data_dir, seq_len=args.seq_len, num_classes=args.num_classes, batch_size=args.batch_size
    )
    model = ModulationClassifier(num_classes=args.num_classes, seq_len=args.seq_len).to(args.device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        train_loss = _train_one_epoch(model, train_loader, opt, criterion, args.device)
        acc = _evaluate(model, val_loader, args.device)
        print(f"Epoch {epoch + 1}/{args.epochs} train_loss={train_loss:.4f} val_acc={acc:.4f}")

    os.makedirs(args.data_dir, exist_ok=True)
    ckpt_path = os.path.join(args.data_dir, "modulation_classifier.pt")
    torch.save(model.state_dict(), ckpt_path)
    print("Model saved to", ckpt_path)


if __name__ == "__main__":
    main()