# Uploaded by syeedalireza via huggingface_hub (commit f43cdb8, verified).
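"""
Training script for the modulation classifier.

Currently trains on randomly generated synthetic I/Q data only; loading of
preprocessed RadioML-style data from DATA_DIR is not wired in yet (see README).
The trained weights are saved to <data_dir>/modulation_classifier.pt.
"""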
import os
import argparse
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from model import ModulationClassifier
def get_synthetic_dataloaders(
    data_dir: str,
    seq_len: int = 128,
    num_classes: int = 6,
    batch_size: int = 64,
    n_train: int = 2000,
    n_val: int = 400,
    seed: "int | None" = 42,
):
    """Build random synthetic I/Q train/val loaders for smoke-testing the pipeline.

    Inputs are standard-normal noise and labels are drawn uniformly at random,
    so there is no learnable signal: expect validation accuracy near
    ``1 / num_classes``. The data only exercises shapes and the training loop.

    Args:
        data_dir: Accepted for interface compatibility with a future real-data
            loader; it is not read here.
        seq_len: Number of time samples per example.
        num_classes: Labels are drawn from ``[0, num_classes)``.
        batch_size: Batch size for both loaders.
        n_train: Number of training examples.
        n_val: Number of validation examples.
        seed: If not ``None``, ``torch.manual_seed(seed)`` is called first.
            This reseeds torch's *global* RNG, so anything the caller draws
            from it afterwards (e.g. model initialisation, DataLoader
            shuffling) is reproducible as well. Pass ``None`` to leave the
            global RNG untouched.

    Returns:
        ``(train_loader, val_loader)``. Each example is a ``(2, seq_len)``
        float tensor (two rows, presumably I and Q) paired with an int64 label.
        The train loader shuffles; the val loader does not.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Draw order (X_train, y_train, X_val, y_val) is kept fixed so that a given
    # seed always yields the same tensors.
    x_train = torch.randn(n_train, 2, seq_len)
    y_train = torch.randint(0, num_classes, (n_train,))
    x_val = torch.randn(n_val, 2, seq_len)
    y_val = torch.randint(0, num_classes, (n_val,))

    train_loader = DataLoader(
        TensorDataset(x_train, y_train), batch_size=batch_size, shuffle=True, num_workers=0
    )
    val_loader = DataLoader(
        TensorDataset(x_val, y_val), batch_size=batch_size, shuffle=False, num_workers=0
    )
    return train_loader, val_loader
def _parse_args(argv=None):
    """Parse training options from *argv* (``sys.argv[1:]`` when ``None``)."""
    parser = argparse.ArgumentParser(description="Train the modulation classifier.")
    parser.add_argument("--data_dir", type=str, default="./data", help="Path to data directory")
    parser.add_argument("--seq_len", type=int, default=128)
    parser.add_argument("--num_classes", type=int, default=6)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    return parser.parse_args(argv)


def _train_one_epoch(model, loader, opt, criterion, device):
    """Run one optimisation pass over *loader* and return the per-sample mean loss.

    Returns NaN if *loader* yields no samples.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        opt.step()
        # Assumes criterion uses reduction="mean" (the nn.CrossEntropyLoss
        # default used by main). Weight each batch by its size so a short final
        # batch does not count as much as a full one.
        batch = y.size(0)
        loss_sum += loss.item() * batch
        n_seen += batch
    return loss_sum / n_seen if n_seen else float("nan")


def _evaluate(model, loader, device):
    """Return top-1 accuracy of *model* over *loader* (NaN if it yields no samples)."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            total += y.size(0)
            correct += (pred == y).sum().item()
    return correct / total if total else float("nan")


def main(argv=None):
    """Train ModulationClassifier on synthetic data and save its state_dict.

    Each epoch prints the sample-weighted mean training loss and the
    validation accuracy. The final weights are written to
    ``<data_dir>/modulation_classifier.pt``; the directory is created if it
    does not exist.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            means ``sys.argv[1:]``.
    """
    args = _parse_args(argv)
    # The loaders are built before the model on purpose: get_synthetic_dataloaders
    # reseeds torch's global RNG, which makes the model initialisation below
    # reproducible too.
    train_loader, val_loader = get_synthetic_dataloaders(
        args.data_dir, seq_len=args.seq_len, num_classes=args.num_classes, batch_size=args.batch_size
    )
    model = ModulationClassifier(num_classes=args.num_classes, seq_len=args.seq_len).to(args.device)
    opt = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(args.epochs):
        train_loss = _train_one_epoch(model, train_loader, opt, criterion, args.device)
        acc = _evaluate(model, val_loader, args.device)
        print(f"Epoch {epoch + 1}/{args.epochs} train_loss={train_loss:.4f} val_acc={acc:.4f}")

    os.makedirs(args.data_dir, exist_ok=True)
    save_path = os.path.join(args.data_dir, "modulation_classifier.pt")
    torch.save(model.state_dict(), save_path)
    print("Model saved to", save_path)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when it
    # is imported (e.g. to reuse get_synthetic_dataloaders).
    main()