| """
|
| Training script for the modulation classifier.
|
| Expects data in the directory given by --data_dir: either synthetic or preprocessed RadioML-style (see README).
|
| """
|
|
|
| import os
|
| import argparse
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils.data import DataLoader, TensorDataset
|
|
|
| from model import ModulationClassifier
|
|
|
|
|
def get_synthetic_dataloaders(data_dir: str, seq_len: int = 128, num_classes: int = 6, batch_size: int = 64):
    """Create train/val DataLoaders over random Gaussian I/Q tensors.

    Stand-in used when no real dataset is present, so the training pipeline
    can be exercised end to end. ``data_dir`` is accepted only for interface
    parity with a real-data loader and is never read.

    Returns:
        (train_loader, val_loader): 2000/400 samples of shape (2, seq_len)
        with integer labels in [0, num_classes).
    """
    split_sizes = {"train": 2000, "val": 400}
    # Fixed seed so repeated runs see the same synthetic data.
    torch.manual_seed(42)

    def _random_dataset(n: int) -> TensorDataset:
        # One call pair per split: features first, then labels (RNG order matters).
        return TensorDataset(
            torch.randn(n, 2, seq_len),
            torch.randint(0, num_classes, (n,)),
        )

    datasets = {split: _random_dataset(n) for split, n in split_sizes.items()}
    return (
        DataLoader(datasets["train"], batch_size=batch_size, shuffle=True, num_workers=0),
        DataLoader(datasets["val"], batch_size=batch_size, shuffle=False, num_workers=0),
    )
|
|
|
|
|
def _train_one_epoch(model, loader, opt, criterion, device):
    """Run one optimization pass over ``loader``; return mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate(model, loader, device):
    """Return classification accuracy of ``model`` over ``loader`` (no grads)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            total += y.size(0)
            correct += (pred == y).sum().item()
    # Guard against an empty loader: report 0.0 instead of dividing by zero.
    return correct / total if total else 0.0


def main():
    """Parse CLI args, train the modulation classifier, and save its weights.

    Data is currently synthetic (see get_synthetic_dataloaders); the trained
    state dict is written to <data_dir>/modulation_classifier.pt.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir", type=str, default="./data", help="Path to data directory")
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--num_classes", type=int, default=6)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    train_loader, val_loader = get_synthetic_dataloaders(
        args.data_dir, seq_len=args.seq_len, num_classes=args.num_classes, batch_size=args.batch_size
    )

    model = ModulationClassifier(num_classes=args.num_classes, seq_len=args.seq_len).to(args.device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        train_loss = _train_one_epoch(model, train_loader, opt, criterion, args.device)
        acc = _evaluate(model, val_loader, args.device)
        print(f"Epoch {epoch + 1}/{args.epochs} train_loss={train_loss:.4f} val_acc={acc:.4f}")

    # Checkpoint goes next to the data by convention; create the dir if missing.
    os.makedirs(args.data_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(args.data_dir, "modulation_classifier.pt"))
    print("Model saved to", os.path.join(args.data_dir, "modulation_classifier.pt"))
|
|
|
|
|
# Entry point: run training when executed as a script, but keep the module
# importable (e.g. for tests) without side effects.
if __name__ == "__main__":
    main()
|
|
|