Radon-Averaging / main.py
jangwon-kim-cocel's picture
Upload 6 files
0b63116 verified
import argparse
import torch
from model import RA
from trainer import train
import random
import warnings
warnings.filterwarnings('ignore')
from utils import *
def parse_args():
parser = argparse.ArgumentParser(description="Train & Test with configurable args")
parser.add_argument("--num_train_data", type=int, default=10000, help="number of training samples")
parser.add_argument("--batch_size", type=int, default=64, help="batch size")
parser.add_argument("--epochs", type=int, default=50, help="number of epochs")
parser.add_argument("--group", type=str, default="C8", help="group name for RA model (C4 or C8)")
parser.add_argument("--n_seeds", type=int, default=10, help="number of seeds to run")
return parser.parse_args()
def train_and_test(model, train_loader, val_loader, test_loader, epochs, device):
score, f1 = train(
model=model,
epochs=epochs,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
lr=1e-4,
wd=1e-4,
device=device,
)
return score, f1
def main(args):
device = "cuda" if torch.cuda.is_available() else "cpu"
base_seed = random.randint(1, 999)
accs = []
f1s = []
print_run_config(args, device, base_seed)
for seed in range(base_seed, base_seed + args.n_seeds):
train_dataset = MnistDataset(num_train_data=args.num_train_data, mode="train", seed=seed)
validation_dataset = MnistDataset(mode="validation", seed=seed)
test_dataset = MnistDataset(mode="test", seed=seed)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size)
val_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=args.batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size)
model = RA(group=args.group).to(device)
score, f1 = train_and_test(
model=model,
train_loader=train_loader,
val_loader=val_loader,
test_loader=test_loader,
epochs=args.epochs,
device=device,
)
accs.append(score)
f1s.append(f1)
a_m, a_std = cal_mean_std(accs)
print(f"[Acc] Mean: {a_m} | Std: {a_std}\n")
f_m, f_std = cal_mean_std(f1s)
print(f"[F-Score] Mean: {f_m} | Std: {f_std}\n")
if __name__ == "__main__":
args = parse_args()
main(args)