jangwon-kim-cocel commited on
Commit
0b63116
·
verified ·
1 Parent(s): b847249

Upload 6 files

Browse files
Files changed (7) hide show
  1. .gitattributes +1 -0
  2. README.md +71 -0
  3. gif_for_readme.gif +3 -0
  4. main.py +75 -0
  5. model.py +113 -0
  6. trainer.py +82 -0
  7. utils.py +124 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ gif_for_readme.gif filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>Radon Averaging</h1>
3
+ <h3>A Practical Approach for Designing Rotation-Invariant Models</h3>
4
+
5
+ <a href="https://www.python.org/">
6
+ <img src="https://img.shields.io/badge/Python-3.10+-blue?logo=python&style=flat-square" alt="Python Badge"/>
7
+ </a>
8
+ &nbsp;&nbsp;
9
+ <a href="https://pytorch.org/">
10
+ <img src="https://img.shields.io/badge/PyTorch-2.0+-EE4C2C?logo=pytorch&style=flat-square" alt="PyTorch Badge"/>
11
+ </a>
12
+ &nbsp;&nbsp;
13
+ <a href="https://doi.org/10.1016/j.engappai.2025.113299">
14
+ <img src="https://img.shields.io/badge/EAAI%202026-Published-success?style=flat-square" alt="EAAI Badge"/>
15
+ </a>
16
+ &nbsp;&nbsp;
17
+ <a href="https://www.elsevier.com/">
18
+ <img src="https://img.shields.io/badge/Elsevier-Journal-orange?style=flat-square" alt="Elsevier Badge"/>
19
+ </a>
20
+ <br/><br/>
21
+
22
+ <!-- Radon Transform Animation -->
23
+ <img src="./gif_for_readme.gif" width="700px"/>
24
+ </div>
25
+
26
+ ---
27
+
28
+ ## Engineering Applications of Artificial Intelligence (EAAI 2026)
29
+ ### PyTorch Implementation
30
+
31
+ This repository contains a PyTorch implementation of **Radon Averaging (RA)** from the paper:
32
+
33
+ > **Radon Averaging: A practical approach for designing rotation-invariant models**
34
+ > Jangwon Kim, Sanghyun Ryoo, Jiwon Kim, Junkee Hong, Soohee Han
35
+ > *Engineering Applications of Artificial Intelligence*, Volume 164, 2026
36
+
37
+ ## 📄 Paper Link
38
+ > **DOI:** https://doi.org/10.1016/j.engappai.2025.113299
39
+ > **Journal:** Engineering Applications of Artificial Intelligence
40
+
41
+ ---
42
+
43
+ ## Radon Averaging
44
+
45
+ Radon Averaging achieves rotation invariance by:
46
+ 1. **Radon Transform** (ℛ): Converts images ($I$) to sinograms, where an rotation corresponds ($$g$$) to a circular shift.
47
+ 2. **Averaging over Discrete Rotations** ($$G$$): Eliminates boundary artifacts via group averaging
48
+ 3. **Standard CNN Backbone** ($$Φ$$): No architectural changes required
49
+ ```math
50
+ RA_G^Φ(I) = \frac{1}{|G|} \sum_{g \in G} (Φ \circ π(g) \circ ℛ)(I)
51
+ ```
52
+ ---
53
+
54
+ ## Advantages
55
+ - **Plug-and-play**: works with standard (pretrained) CNN backbones (no architectural changes).
56
+ - **Rotation invariance in practice**: stable representations under image rotations.
57
+ - **Reduces boundary artifacts**: group averaging mitigates Radon transform edge effects.
58
+ ---
59
+
60
+ ## Citation Example
61
+ ```
62
+ @article{kim2026radonaveraging,
63
+ title = {Radon Averaging: A practical approach for designing rotation-invariant models},
64
+ author = {Kim, Jangwon and Ryoo, Sanghyun and Kim, Jiwon and Hong, Junkee and Han, Soohee},
65
+ journal = {Engineering Applications of Artificial Intelligence},
66
+ volume = {164},
67
+ pages = {113299},
68
+ year = {2026},
69
+ doi = {10.1016/j.engappai.2025.113299}
70
+ }
71
+ ```
gif_for_readme.gif ADDED

Git LFS Details

  • SHA256: 9282f945446d00ba50d01c0d287a4d7b1bbf398bdd468ddbad17094ad5907c93
  • Pointer size: 132 Bytes
  • Size of remote file: 4.55 MB
main.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+ from model import RA
4
+ from trainer import train
5
+ import random
6
+ import warnings
7
+ warnings.filterwarnings('ignore')
8
+
9
+ from utils import *
10
+
11
+ def parse_args():
12
+ parser = argparse.ArgumentParser(description="Train & Test with configurable args")
13
+ parser.add_argument("--num_train_data", type=int, default=10000, help="number of training samples")
14
+ parser.add_argument("--batch_size", type=int, default=64, help="batch size")
15
+ parser.add_argument("--epochs", type=int, default=50, help="number of epochs")
16
+ parser.add_argument("--group", type=str, default="C8", help="group name for RA model (C4 or C8)")
17
+ parser.add_argument("--n_seeds", type=int, default=10, help="number of seeds to run")
18
+
19
+ return parser.parse_args()
20
+
21
+
22
def train_and_test(model, train_loader, val_loader, test_loader, epochs, device):
    """Run the full train/validate/test cycle and return (accuracy, f1).

    Thin wrapper around `train` that fixes the optimizer hyperparameters
    (lr=1e-4, weight decay=1e-4) used throughout the experiments.
    """
    return train(
        model=model,
        epochs=epochs,
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        lr=1e-4,
        wd=1e-4,
        device=device,
    )
34
+
35
+
36
def main(args):
    """Run the RA experiment over `args.n_seeds` consecutive seeds and print
    the mean/std of test accuracy and weighted F1 across seeds."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    base_seed = random.randint(1, 999)

    accs = []
    f1s = []

    print_run_config(args, device, base_seed)

    for seed in range(base_seed, base_seed + args.n_seeds):
        # Fresh dataset splits and a fresh model for every seed.
        def make_loader(dataset):
            return torch.utils.data.DataLoader(dataset, batch_size=args.batch_size)

        train_loader = make_loader(MnistDataset(num_train_data=args.num_train_data, mode="train", seed=seed))
        val_loader = make_loader(MnistDataset(mode="validation", seed=seed))
        test_loader = make_loader(MnistDataset(mode="test", seed=seed))

        model = RA(group=args.group).to(device)
        score, f1 = train_and_test(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            test_loader=test_loader,
            epochs=args.epochs,
            device=device,
        )
        accs.append(score)
        f1s.append(f1)

    a_m, a_std = cal_mean_std(accs)
    print(f"[Acc] Mean: {a_m} | Std: {a_std}\n")

    f_m, f_std = cal_mean_std(f1s)
    print(f"[F-Score] Mean: {f_m} | Std: {f_std}\n")
71
+
72
+
73
if __name__ == "__main__":
    # Entry point: parse CLI args and launch the multi-seed experiment.
    main(parse_args())
model.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import division
2
+ import copy
3
+ import torch
4
+ import torchvision.transforms.functional as TF
5
+ from torchvision import models
6
+ from torchvision.transforms import Pad
7
+ from torchvision.transforms import Resize
8
+ from torchvision.transforms import ToTensor
9
+ from skimage.transform import radon
10
+ from PIL import Image
11
+ import numpy as np
12
+
13
+
14
# Sinogram width (number of detector bins) and, equally, the number of
# projection angles fed to skimage's `radon`.
WIDTH = 29
# Projection angles in degrees: WIDTH samples evenly spaced over [0, 360).
# Full-circle coverage is what lets an image rotation act as a circular
# shift of the sinogram columns (see README).
THETA = np.linspace(0.0, 360.0, num=WIDTH, endpoint=False)
16
+
17
+
18
class RA(torch.nn.Module):
    """Radon Averaging (RA) classifier.

    Converts each input image into a sinogram via the Radon transform and
    averages the backbone logits over the rotations of a cyclic group
    ('C1', 'C4' or 'C8'), making the prediction rotation-invariant in
    practice without changing the backbone architecture.
    """

    # Rotation angles (degrees) for each supported cyclic group.
    GROUPS = {
        'C1': (0,),
        'C4': (0, 90, 180, 270),
        'C8': (0, 45, 90, 135, 180, 225, 270, 315),
    }

    def __init__(self, group='C4'):
        super(RA, self).__init__()
        self.backbone = MnistCNN()
        self.group = group
        self.resize1 = Resize(87)  # to upsample before rotating (limits interpolation loss)
        self.resize2 = Resize(29)  # to downsample back to the backbone's 29x29 input
        self.totensor = ToTensor()

    def forward(self, x):
        """Return class logits averaged over the group's rotations.

        Args:
            x: float tensor of shape (batch, 1, 29, 29). The Radon transform
               runs on CPU via skimage, so inputs are staged through CPU.

        Raises:
            NameError: if `self.group` is not one of 'C1', 'C4', 'C8'.
        """
        if self.group not in self.GROUPS:
            raise NameError(f"unsupported group: {self.group!r}")
        angles = self.GROUPS[self.group]

        # Fix: run the backbone on whatever device the model actually lives
        # on, instead of the previously hard-coded 'cuda' (which crashed on
        # CPU-only machines).
        device = next(self.backbone.parameters()).device

        x = x.to('cpu')
        # Keep a pristine copy of the inputs; sinograms are written into x.
        # (clone() is sufficient for a tensor; deepcopy was wasteful.)
        source = x.clone()

        if self.group == 'C1':
            # Trivial group: one sinogram per image, no rotation averaging.
            for i in range(x.shape[0]):
                sinogram = radon(source[i][0].numpy(), theta=THETA)
                x[i] = torch.FloatTensor(sinogram).reshape(1, 29, 29)
            return self.backbone(x.to(device))

        logits_sum = 0
        for angle in angles:
            for i in range(x.shape[0]):
                # Upsample -> rotate -> downsample, then take the sinogram.
                img = Image.fromarray(source[i][0].numpy(), mode='F')
                rotated = self.resize2(self.resize1(img).rotate(angle, Image.BILINEAR))
                sinogram = radon(self.totensor(rotated).numpy()[0], theta=THETA)
                x[i] = torch.FloatTensor(sinogram).reshape(1, 29, 29)
            logits_sum = logits_sum + self.backbone(x.to(device))

        # Group average of the logits.
        return logits_sum / len(angles)
68
+
69
+
70
class MnistCNN(torch.nn.Module):
    """Small CNN for 29x29 single-channel inputs producing 10 logits.

    Three conv+ReLU+maxpool stages, a dropout-regularized hidden linear
    layer, and a final linear classifier. Weights are orthogonally
    initialized; biases start at zero.
    """

    def __init__(self):
        super(MnistCNN, self).__init__()
        self.keep_prob = 0.5  # dropout keeps activations with this probability

        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))

        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2))

        self.layer3 = torch.nn.Sequential(
            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=1))

        # 29 -> 14 -> 7 -> 4 spatially, so the flattened size is 128*4*4 = 2048.
        self.fc1 = torch.nn.Linear(2048, 128, bias=True)
        self.layer4 = torch.nn.Sequential(
            self.fc1,
            torch.nn.ReLU(),
            torch.nn.Dropout(p=1 - self.keep_prob))

        self.fc2 = torch.nn.Linear(128, 10, bias=True)
        self._initialize_weights()

    def _initialize_weights(self):
        """Orthogonal weights and zero biases for every conv/linear module."""
        for module in self.modules():
            if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.orthogonal_(module.weight)
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Map a (batch, 1, 29, 29) tensor to (batch, 10) logits."""
        out = x
        for stage in (self.layer1, self.layer2, self.layer3):
            out = stage(out)
        out = out.flatten(1)  # same as view(batch, -1) for contiguous input
        return self.fc2(self.layer4(out))
trainer.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ from tqdm.auto import tqdm
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ from torch.utils.data import Dataset
9
+ from torchvision.transforms import Pad
10
+ from torchvision.transforms import Resize
11
+ from torchvision.transforms import ToTensor
12
+ from skimage.transform import radon, rescale
13
+ from scipy.ndimage import rotate
14
+ import random
15
+ from sklearn.metrics import f1_score
16
+
17
+
18
def test(model, test_loader, device, final_test=False):
    """Evaluate `model` on `test_loader`.

    Returns:
        (accuracy_percent, weighted_f1) over the full loader.
        Prints a summary line when `final_test` is True.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    preds = []
    targets = []

    with torch.no_grad():
        for x, t in test_loader:
            x = x.to(device)
            t = t.to(device)
            logits = model(x).view(-1, 10)

            # torch.max over dim 1 yields (values, indices); we want indices.
            prediction = logits.data.max(1)[1]
            n_seen += t.shape[0]
            n_correct += (prediction == t).sum().item()

            # Accumulate for the F1 computation over the whole set.
            preds.extend(prediction.cpu().numpy())
            targets.extend(t.cpu().numpy())

    f1 = f1_score(targets, preds, average='weighted')
    accuracy = n_correct / n_seen * 100.0
    if final_test:
        print(f"[Final Test] Acc: {accuracy} | F1-Score: {f1}\n")

    return accuracy, f1
47
+
48
+
49
def train(model, epochs, train_loader, val_loader, test_loader, lr=1e-4, wd=1e-4, device='cuda'):
    """Train `model`, keep the best validation checkpoint, and evaluate it.

    Validates after every epoch; the model with the highest validation
    accuracy is deep-copied (on CPU) and finally evaluated on `test_loader`.

    Returns:
        (test_accuracy_percent, weighted_f1) of the best-validation model.
    """
    best_acc = 0.0
    best_model = None
    loss_function = torch.nn.CrossEntropyLoss()
    # Only the CNN backbone has trainable parameters; the Radon transform
    # part of RA is parameter-free.
    optimizer = torch.optim.Adam(model.backbone.parameters(), lr=lr, weight_decay=wd)

    for epoch in range(epochs):
        model.train()
        for x, t in train_loader:
            optimizer.zero_grad()
            x = x.to(device)
            t = t.to(device)
            y = model(x).view(-1, 10)
            loss = loss_function(y, t)
            loss.backward()
            optimizer.step()
            del x, y, t, loss

        # Validate every epoch (the original `(epoch + 1) % 1 == 0` guard was
        # always true and has been removed).
        accuracy, _ = test(model, val_loader, device=device)
        print(f"epoch {epoch + 1} | validation accuracy: {accuracy}")
        if accuracy > best_acc:
            best_acc = accuracy
            # Checkpoint on CPU so the copy does not hold GPU memory.
            best_model = copy.deepcopy(model.to('cpu'))
            model = model.to(device)  # fix: was hard-coded 'cuda'

    print(f"Max validation accuracy: {best_acc}\n")
    if best_model is None:
        # Robustness fix: no epoch improved on 0.0 (e.g. epochs == 0);
        # fall back to the current weights instead of crashing on None.
        best_model = model
    best_model = best_model.to(device)  # fix: was hard-coded 'cuda'
    score, f1 = test(best_model, test_loader, device=device, final_test=True)
    del best_model
    del model
    return score, f1
utils.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import shutil
3
+ import torch
4
+ import numpy as np
5
+ import tempfile
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+ from torchvision.transforms import Pad, Resize, ToTensor
9
+ from skimage.transform import radon, rescale
10
+ from scipy.ndimage import rotate
11
+
12
+
13
# Google Drive share links for the MNIST ".amat" data files, fetched on demand
# by `_ensure_mnist_amat_files`. These are viewer-style URLs; `_gdrive_download`
# passes fuzzy=True so gdown can extract the file id from them.
MNIST_TRAIN_GDRIVE = "https://drive.google.com/file/d/15G2FsYGRSpEkr5MTVofSFhKaMMIeiFhk/view?usp=drive_link"
MNIST_TEST_GDRIVE = "https://drive.google.com/file/d/1PK1DeFpw2OomuHDoA8ZtTPWLkPHOT6u6/view?usp=drive_link"
15
+
16
def _gdrive_download(url_or_id: str, out_path: Path) -> None:
    """Download a Google Drive file to `out_path` via gdown.

    Downloads into a temp file in the destination directory first, then
    moves it into place, so a partial download never ends up at `out_path`.

    Raises:
        RuntimeError: if gdown reports failure or the file is empty.
    """
    import gdown  # local import: only needed when a download is actually required

    out_path.parent.mkdir(parents=True, exist_ok=True)

    # Reserve a temp file next to the destination so the final move is cheap.
    with tempfile.NamedTemporaryFile(delete=False, dir=str(out_path.parent), suffix=".tmp") as handle:
        tmp_path = Path(handle.name)

    try:
        result = gdown.download(url_or_id, str(tmp_path), quiet=False, fuzzy=True)
        if not result or not tmp_path.exists() or tmp_path.stat().st_size == 0:
            raise RuntimeError(f"[Google Drive] Failed to download: {url_or_id}")

        shutil.move(str(tmp_path), str(out_path))
    finally:
        # On the success path the move already removed the temp file; this
        # cleans up after any failure, best-effort.
        if tmp_path.exists():
            try:
                tmp_path.unlink()
            except Exception:
                pass
36
+
37
+
38
def _ensure_mnist_amat_files(mnist_dir: str = "mnist") -> None:
    """Ensure both MNIST .amat files exist locally, downloading any missing one."""
    base = Path(mnist_dir)
    wanted = [
        (base / "mnist_train.amat", MNIST_TRAIN_GDRIVE, "train"),
        (base / "mnist_test.amat", MNIST_TEST_GDRIVE, "test"),
    ]

    for path, url, split in wanted:
        if path.is_file():
            continue
        print(f"[MNIST] '{path}' not found. Downloading MNIST {split} file from Google Drive...")
        _gdrive_download(url, path)
        print(f"[MNIST] Download complete: {path}")
52
+
53
+
54
class MnistDataset(Dataset):
    """MNIST dataset producing (1, 29, 29) float images and int64 labels.

    mode='train'      : first `num_train_data` samples of a shuffled split of
                        the train file, zero-padded from 28x28 to 29x29,
                        kept upright (no rotation).
    mode='validation' : the remaining 12000 samples of that split, each
                        rotated by a uniformly random angle in [0, 360).
    mode='test'       : the full test file, each image randomly rotated the
                        same way.

    Missing .amat files are downloaded from Google Drive on first use.
    """

    def __init__(self, mode, num_train_data=10000, seed=1):
        # NOTE(review): `assert` is stripped under `python -O`; these are
        # developer sanity checks, not user-input validation.
        assert num_train_data + 12000 <= 45000
        assert mode in ['train', 'validation', 'test']
        if seed is not None:
            # Seed NumPy (split shuffle + rotation angles) and torch so each
            # seed produces a reproducible dataset.
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Download the .amat files if they are not present locally.
        _ensure_mnist_amat_files("mnist")

        if mode == "test":
            file = "mnist/mnist_test.amat"
        else:
            file = "mnist/mnist_train.amat"

        # Each .amat row is 784 pixel values followed by the label.
        data = np.loadtxt(file)
        images = data[:, :-1].reshape(-1, 28, 28).astype(np.float32)

        # Shuffle the images
        # (train and validation take disjoint ranges of one permutation;
        # for mode='test' the permutation is computed but not used).
        indices = np.arange(num_train_data + 12000)
        np.random.shuffle(indices)
        if mode == 'train':
            images = images[indices[:num_train_data]]
            data = data[indices[:num_train_data]]
        elif mode == 'validation':
            images = images[indices[num_train_data:]]
            data = data[indices[num_train_data:]]

        if mode == 'test' or mode == 'validation':
            # Evaluation images are randomly rotated: pad 28->29, upsample to
            # 87, rotate by a random angle, downsample back to 29x29.
            pad = Pad((0, 0, 1, 1), fill=0)
            resize1 = Resize(87)  # to upsample
            resize2 = Resize(29)  # to downsample
            totensor = ToTensor()

            self.images = torch.empty((images.shape[0], 1, 29, 29))
            for i in range(images.shape[0]):
                img = images[i]
                img = Image.fromarray(img, mode='F')  # 32-bit float PIL image
                r = (np.random.rand() * 360.)
                self.images[i] = totensor(resize2(resize1(pad(img)).rotate(r, Image.BILINEAR))).reshape(1, 29, 29)
        else:
            # Training images stay upright; just zero-pad 28x28 -> 29x29.
            self.images = torch.zeros((images.shape[0], 1, 29, 29))
            self.images[:, :, :28, :28] = torch.tensor(images).reshape(-1, 1, 28, 28)

        self.labels = data[:, -1].astype(np.int64)

    def __getitem__(self, index):
        """Return (image, label): a (1, 29, 29) float tensor and an int64 label."""
        image, label = self.images[index], self.labels[index]

        return image, label

    def __len__(self):
        """Number of samples in this split."""
        return len(self.labels)
107
+
108
+
109
def cal_mean_std(ary):
    """Return (mean, population std) of a sequence as numpy scalars."""
    values = np.asarray(ary)
    return values.mean(), values.std()
112
+
113
def print_run_config(args, device, base_seed):
    """Print the run-configuration banner shown before training starts."""
    rows = [
        ("Device", device),
        ("Group", args.group),
        ("Epochs", args.epochs),
        ("Batch size", args.batch_size),
        ("Train samples", args.num_train_data),
        ("Num seeds", args.n_seeds),
        ("Seed range", f"{base_seed} ~ {base_seed + args.n_seeds - 1}"),
    ]

    print("\n" + "=" * 34)
    print("RA Training Configuration")
    print("-" * 34)
    for label, value in rows:
        print(f"{label} : {value}")
    print("=" * 34 + "\n")